From 64670e414eb4bcba7dcc382bb16be1492cdd6e95 Mon Sep 17 00:00:00 2001 From: Pritam Damania Date: Tue, 1 Feb 2022 22:53:18 -0800 Subject: [PATCH] [reland] Create torch.distributed._shard package. (#72141) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72141 We have many sharding components currently: torch.distributed._sharded_tensor, torch.distributed._sharding_spec, torch.distributed._sharded_optimizer and more coming. As a result, organizing all of this under the `torch.distributed._shard` package. For BC reasons, I'm still keeping the old packages and have them just reference the new package. ghstack-source-id: 148150861 ghstack-source-id: 148150861 Test Plan: waitforbuildbot Reviewed By: fduwjj Differential Revision: D33904585 fbshipit-source-id: 057e847eb7521b536a3ee4e0f94871aacc752062 (cherry picked from commit 29a70dd7afde6083bab942081020a13278f38e52) --- .../sharded_optim}/test_sharded_optim.py | 8 +- .../sharded_tensor}/ops/test_binary_cmp.py | 32 +- .../sharded_tensor}/ops/test_embedding.py | 6 +- .../sharded_tensor}/ops/test_embedding_bag.py | 6 +- .../sharded_tensor}/ops/test_init.py | 36 +- .../sharded_tensor}/ops/test_linear.py | 12 +- .../sharded_tensor}/test_sharded_tensor.py | 472 +++++++-------- .../sharding_spec}/test_sharding_spec.py | 4 +- test/run_test.py | 45 +- torch/distributed/_shard/__init__.py | 1 + torch/distributed/_shard/api.py | 145 +++++ .../sharded_optim}/__init__.py | 4 +- .../sharded_optim}/api.py | 2 +- .../_shard/sharded_tensor/__init__.py | 412 +++++++++++++ .../sharded_tensor/_ops}/__init__.py | 0 .../sharded_tensor/_ops}/_common.py | 2 +- .../sharded_tensor/_ops}/binary_cmp.py | 2 +- .../sharded_tensor/_ops}/embedding.py | 6 +- .../sharded_tensor/_ops}/embedding_bag.py | 6 +- .../sharded_tensor/_ops}/init.py | 2 +- .../sharded_tensor/_ops}/linear.py | 8 +- .../sharded_tensor}/api.py | 10 +- .../sharded_tensor}/metadata.py | 2 +- .../sharded_tensor}/shard.py | 4 +- .../sharded_tensor}/utils.py | 4 +- 
.../_shard/sharding_spec/__init__.py | 8 + .../sharding_spec}/_internals.py | 0 .../sharding_spec}/api.py | 0 torch/distributed/_sharded_tensor/__init__.py | 548 +----------------- torch/distributed/_sharding_spec/__init__.py | 18 +- .../_internal/distributed/_shard/__init__.py | 0 .../sharded_tensor}/__init__.py | 0 .../sharded_tensor}/_test_ops_common.py | 4 +- 33 files changed, 927 insertions(+), 882 deletions(-) rename test/distributed/{_sharded_optim => _shard/sharded_optim}/test_sharded_optim.py (96%) rename test/distributed/{_sharded_tensor => _shard/sharded_tensor}/ops/test_binary_cmp.py (81%) rename test/distributed/{_sharded_tensor => _shard/sharded_tensor}/ops/test_embedding.py (96%) rename test/distributed/{_sharded_tensor => _shard/sharded_tensor}/ops/test_embedding_bag.py (97%) rename test/distributed/{_sharded_tensor => _shard/sharded_tensor}/ops/test_init.py (70%) rename test/distributed/{_sharded_tensor => _shard/sharded_tensor}/ops/test_linear.py (96%) rename test/distributed/{_sharded_tensor => _shard/sharded_tensor}/test_sharded_tensor.py (81%) rename test/distributed/{_sharding_spec => _shard/sharding_spec}/test_sharding_spec.py (98%) create mode 100644 torch/distributed/_shard/__init__.py create mode 100644 torch/distributed/_shard/api.py rename torch/distributed/{_sharded_optim => _shard/sharded_optim}/__init__.py (93%) rename torch/distributed/{_sharded_optim => _shard/sharded_optim}/api.py (98%) create mode 100644 torch/distributed/_shard/sharded_tensor/__init__.py rename torch/distributed/{_sharded_tensor/ops => _shard/sharded_tensor/_ops}/__init__.py (100%) rename torch/distributed/{_sharded_tensor/ops => _shard/sharded_tensor/_ops}/_common.py (99%) rename torch/distributed/{_sharded_tensor/ops => _shard/sharded_tensor/_ops}/binary_cmp.py (97%) rename torch/distributed/{_sharded_tensor/ops => _shard/sharded_tensor/_ops}/embedding.py (98%) rename torch/distributed/{_sharded_tensor/ops => _shard/sharded_tensor/_ops}/embedding_bag.py (99%) 
rename torch/distributed/{_sharded_tensor/ops => _shard/sharded_tensor/_ops}/init.py (98%) rename torch/distributed/{_sharded_tensor/ops => _shard/sharded_tensor/_ops}/linear.py (97%) rename torch/distributed/{_sharded_tensor => _shard/sharded_tensor}/api.py (98%) rename torch/distributed/{_sharded_tensor => _shard/sharded_tensor}/metadata.py (97%) rename torch/distributed/{_sharded_tensor => _shard/sharded_tensor}/shard.py (93%) rename torch/distributed/{_sharded_tensor => _shard/sharded_tensor}/utils.py (98%) create mode 100644 torch/distributed/_shard/sharding_spec/__init__.py rename torch/distributed/{_sharding_spec => _shard/sharding_spec}/_internals.py (100%) rename torch/distributed/{_sharding_spec => _shard/sharding_spec}/api.py (100%) create mode 100644 torch/testing/_internal/distributed/_shard/__init__.py rename torch/testing/_internal/distributed/{_sharded_tensor => _shard/sharded_tensor}/__init__.py (100%) rename torch/testing/_internal/distributed/{_sharded_tensor => _shard/sharded_tensor}/_test_ops_common.py (94%) diff --git a/test/distributed/_sharded_optim/test_sharded_optim.py b/test/distributed/_shard/sharded_optim/test_sharded_optim.py similarity index 96% rename from test/distributed/_sharded_optim/test_sharded_optim.py rename to test/distributed/_shard/sharded_optim/test_sharded_optim.py index c72e458f2e9..085c928985e 100644 --- a/test/distributed/_sharded_optim/test_sharded_optim.py +++ b/test/distributed/_shard/sharded_optim/test_sharded_optim.py @@ -2,13 +2,13 @@ import torch import torch.optim as optim -import torch.distributed._sharded_tensor as sharded_tensor +import torch.distributed._shard.sharded_tensor from copy import deepcopy -from torch.distributed._sharding_spec import ( +from torch.distributed._shard.sharding_spec import ( ChunkShardingSpec, ) -from torch.distributed._sharded_optim import ( +from torch.distributed._shard.sharded_optim import ( ShardedOptimizer, named_params_with_sharded_tensor ) @@ -20,7 +20,7 @@ from 
torch.testing._internal.common_utils import ( run_tests, ) -from torch.testing._internal.distributed._sharded_tensor import ( +from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, with_comms, ) diff --git a/test/distributed/_sharded_tensor/ops/test_binary_cmp.py b/test/distributed/_shard/sharded_tensor/ops/test_binary_cmp.py similarity index 81% rename from test/distributed/_sharded_tensor/ops/test_binary_cmp.py rename to test/distributed/_shard/sharded_tensor/ops/test_binary_cmp.py index 362dd5fd9ff..c2072716952 100644 --- a/test/distributed/_sharded_tensor/ops/test_binary_cmp.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_binary_cmp.py @@ -4,17 +4,17 @@ import sys import torch import torch.distributed as dist -from torch.distributed import _sharded_tensor +from torch.distributed._shard import sharded_tensor from torch.distributed.distributed_c10d import _get_default_group from torch.testing._internal.common_distributed import ( requires_nccl, skip_if_lt_x_gpu, ) -from torch.distributed._sharding_spec import ( +from torch.distributed._shard.sharding_spec import ( ChunkShardingSpec, ) -from torch.testing._internal.distributed._sharded_tensor import ( +from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, with_comms, ) @@ -35,9 +35,9 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase): pg1 = _get_default_group() if pg1 is None else pg1 pg2 = _get_default_group() if pg2 is None else pg2 torch.manual_seed(TestShardedTensorBinaryOps.seed) - st1 = _sharded_tensor.rand(spec1, sizes, process_group=pg1) + st1 = sharded_tensor.rand(spec1, sizes, process_group=pg1) torch.manual_seed(TestShardedTensorBinaryOps.seed + seed_offset) - st2 = _sharded_tensor.rand(spec2, sizes, process_group=pg2) + st2 = sharded_tensor.rand(spec2, sizes, process_group=pg2) TestShardedTensorBinaryOps.seed += 1 return st1, st2 @@ -72,23 +72,23 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase): 
torch.nn.init.uniform_(st1.local_shards()[0].tensor) self.assertFalse(cmp_op(st1, st2)) - st1 = _sharded_tensor.ones(spec, 10, 10) - st2 = _sharded_tensor.ones(spec, 10, 5) + st1 = sharded_tensor.ones(spec, 10, 10) + st2 = sharded_tensor.ones(spec, 10, 5) self.assertFalse(cmp_op(st1, st2)) st1, st2 = self.get_random_tensors(spec, alt_spec, 10, 10) self.assertFalse(cmp_op(st1, st2)) - st1 = _sharded_tensor.ones(spec, 10, 10) - st2 = _sharded_tensor.zeros(spec, 10, 10) + st1 = sharded_tensor.ones(spec, 10, 10) + st2 = sharded_tensor.zeros(spec, 10, 10) self.assertFalse(cmp_op(st1, st2)) - st1 = _sharded_tensor.ones(spec, 10, 10) - st2 = _sharded_tensor.ones(spec, 10, 10, dtype=torch.double) + st1 = sharded_tensor.ones(spec, 10, 10) + st2 = sharded_tensor.ones(spec, 10, 10, dtype=torch.double) self.assertFalse(cmp_op(st1, st2)) - st1 = _sharded_tensor.ones(spec, 10, 10) - st2 = _sharded_tensor.ones(spec, 10, 10, requires_grad=True) + st1 = sharded_tensor.ones(spec, 10, 10) + st2 = sharded_tensor.ones(spec, 10, 10, requires_grad=True) self.assertFalse(cmp_op(st1, st2)) cpu_spec = ChunkShardingSpec( @@ -100,8 +100,8 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase): "rank:3/cpu", ], ) - st1 = _sharded_tensor.ones(cpu_spec, 10, 10) - st2 = _sharded_tensor.ones(cpu_spec, 10, 10, pin_memory=True) + st1 = sharded_tensor.ones(cpu_spec, 10, 10) + st2 = sharded_tensor.ones(cpu_spec, 10, 10, pin_memory=True) self.assertFalse(cmp_op(st1, st2)) pg = dist.new_group([1, 0, 3, 2]) @@ -149,7 +149,7 @@ class TestShardedTensorBinaryOps(ShardedTensorTestBase): # compare different arrays st1, st2 = self.get_random_tensors(spec, spec, 10, 10, seed_offset=1) self.assertFalse(torch.allclose(st1, st2)) - # _sharded_tensor.rand produces uniform values in the [0,1] range. + # sharded_tensor.rand produces uniform values in the [0,1] range. 
self.assertTrue(torch.allclose(st1, st2, atol=1)) if __name__ == '__main__': diff --git a/test/distributed/_sharded_tensor/ops/test_embedding.py b/test/distributed/_shard/sharded_tensor/ops/test_embedding.py similarity index 96% rename from test/distributed/_sharded_tensor/ops/test_embedding.py rename to test/distributed/_shard/sharded_tensor/ops/test_embedding.py index 1a366148859..1fce221eab9 100644 --- a/test/distributed/_sharded_tensor/ops/test_embedding.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_embedding.py @@ -4,7 +4,7 @@ import sys import torch import torch.distributed as dist -from torch.distributed._sharded_tensor import ( +from torch.distributed._shard import ( shard_parameter, ) from torch.testing._internal.common_distributed import ( @@ -15,12 +15,12 @@ from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, run_tests, ) -from torch.testing._internal.distributed._sharded_tensor import ( +from torch.testing._internal.distributed._shard.sharded_tensor import ( TEST_GPU_NUM, ShardedTensorTestBase, with_comms, ) -from torch.testing._internal.distributed._sharded_tensor._test_ops_common import ( +from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import ( generate_chunk_sharding_specs_for_test, generate_local_weight_sharding_params_for_test, ) diff --git a/test/distributed/_sharded_tensor/ops/test_embedding_bag.py b/test/distributed/_shard/sharded_tensor/ops/test_embedding_bag.py similarity index 97% rename from test/distributed/_sharded_tensor/ops/test_embedding_bag.py rename to test/distributed/_shard/sharded_tensor/ops/test_embedding_bag.py index 9080fd58cd9..a9c38937b5a 100644 --- a/test/distributed/_sharded_tensor/ops/test_embedding_bag.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_embedding_bag.py @@ -4,7 +4,7 @@ import sys import torch import torch.distributed as dist -from torch.distributed._sharded_tensor import ( +from torch.distributed._shard import ( shard_parameter, ) from 
torch.testing._internal.common_distributed import ( @@ -15,12 +15,12 @@ from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, run_tests, ) -from torch.testing._internal.distributed._sharded_tensor import ( +from torch.testing._internal.distributed._shard.sharded_tensor import ( TEST_GPU_NUM, ShardedTensorTestBase, with_comms, ) -from torch.testing._internal.distributed._sharded_tensor._test_ops_common import ( +from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import ( generate_chunk_sharding_specs_for_test, generate_local_weight_sharding_params_for_test, ) diff --git a/test/distributed/_sharded_tensor/ops/test_init.py b/test/distributed/_shard/sharded_tensor/ops/test_init.py similarity index 70% rename from test/distributed/_sharded_tensor/ops/test_init.py rename to test/distributed/_shard/sharded_tensor/ops/test_init.py index 92860ac2fc9..6cbfd04b210 100644 --- a/test/distributed/_sharded_tensor/ops/test_init.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_init.py @@ -3,15 +3,15 @@ import sys import torch -from torch.distributed import _sharded_tensor -from torch.distributed._sharding_spec import ( +from torch.distributed._shard import sharded_tensor +from torch.distributed._shard.sharding_spec import ( ChunkShardingSpec, ) from torch.testing._internal.common_distributed import ( requires_nccl, skip_if_lt_x_gpu, ) -from torch.testing._internal.distributed._sharded_tensor import ( +from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, with_comms, ) @@ -50,17 +50,17 @@ class TestShardedTensorNNInit(ShardedTensorTestBase): seed = 1234 dtype = torch.double - sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype) - self.assertEqual(1, len(sharded_tensor.local_shards())) + st = sharded_tensor.empty(spec, h, w, dtype=dtype) + self.assertEqual(1, len(st.local_shards())) # Clone local tensor to ensure torch.nn.init starts from the same input - local_tensor_clone = 
torch.clone(sharded_tensor.local_shards()[0].tensor) + local_tensor_clone = torch.clone(st.local_shards()[0].tensor) torch.manual_seed(seed) - torch.nn.init.uniform_(sharded_tensor, a=a, b=b) + torch.nn.init.uniform_(st, a=a, b=b) torch.manual_seed(seed) torch.nn.init.uniform_(local_tensor_clone, a=a, b=b) - self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor) + self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor) @with_comms @skip_if_lt_x_gpu(4) @@ -85,17 +85,17 @@ class TestShardedTensorNNInit(ShardedTensorTestBase): seed = 1234 dtype = torch.double - sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype) - self.assertEqual(1, len(sharded_tensor.local_shards())) + st = sharded_tensor.empty(spec, h, w, dtype=dtype) + self.assertEqual(1, len(st.local_shards())) # Clone local tensor to ensure torch.nn.init starts from the same input - local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor) + local_tensor_clone = torch.clone(st.local_shards()[0].tensor) torch.manual_seed(seed) - torch.nn.init.normal_(sharded_tensor, mean=mean, std=std) + torch.nn.init.normal_(st, mean=mean, std=std) torch.manual_seed(seed) torch.nn.init.normal_(local_tensor_clone, mean=mean, std=std) - self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor) + self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor) @with_comms @skip_if_lt_x_gpu(4) @@ -120,17 +120,17 @@ class TestShardedTensorNNInit(ShardedTensorTestBase): seed = 1234 dtype = torch.double - sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype) - self.assertEqual(1, len(sharded_tensor.local_shards())) + st = sharded_tensor.empty(spec, h, w, dtype=dtype) + self.assertEqual(1, len(st.local_shards())) # Clone local tensor to ensure torch.nn.init starts from the same input - local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor) + local_tensor_clone = torch.clone(st.local_shards()[0].tensor) torch.manual_seed(seed) - 
torch.nn.init.kaiming_uniform_(sharded_tensor, a=a, mode=mode, nonlinearity=nonlinearity) + torch.nn.init.kaiming_uniform_(st, a=a, mode=mode, nonlinearity=nonlinearity) torch.manual_seed(seed) torch.nn.init.kaiming_uniform_(local_tensor_clone, a=a, mode=mode, nonlinearity=nonlinearity) - self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor) + self.assertEqual(local_tensor_clone, st.local_shards()[0].tensor) if __name__ == '__main__': run_tests() diff --git a/test/distributed/_sharded_tensor/ops/test_linear.py b/test/distributed/_shard/sharded_tensor/ops/test_linear.py similarity index 96% rename from test/distributed/_sharded_tensor/ops/test_linear.py rename to test/distributed/_shard/sharded_tensor/ops/test_linear.py index 352dd11707f..398e72e2ae3 100644 --- a/test/distributed/_sharded_tensor/ops/test_linear.py +++ b/test/distributed/_shard/sharded_tensor/ops/test_linear.py @@ -4,16 +4,16 @@ import sys import torch import torch.distributed as dist -from torch.distributed._sharded_tensor import ( +from torch.distributed._shard import shard_parameter +from torch.distributed._shard.sharded_tensor import ( empty, - shard_parameter, ) -from torch.distributed._sharding_spec import ( +from torch.distributed._shard.sharding_spec import ( ChunkShardingSpec, EnumerableShardingSpec, ShardMetadata ) -from torch.distributed._sharded_optim import ( +from torch.distributed._shard.sharded_optim import ( ShardedOptimizer, named_params_with_sharded_tensor, ) @@ -25,12 +25,12 @@ from torch.testing._internal.common_utils import ( TEST_WITH_DEV_DBG_ASAN, run_tests, ) -from torch.testing._internal.distributed._sharded_tensor import ( +from torch.testing._internal.distributed._shard.sharded_tensor import ( TEST_GPU_NUM, ShardedTensorTestBase, with_comms, ) -from torch.testing._internal.distributed._sharded_tensor._test_ops_common import ( +from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import ( 
generate_chunk_sharding_specs_for_test, generate_local_weight_sharding_params_for_test, ) diff --git a/test/distributed/_sharded_tensor/test_sharded_tensor.py b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py similarity index 81% rename from test/distributed/_sharded_tensor/test_sharded_tensor.py rename to test/distributed/_shard/sharded_tensor/test_sharded_tensor.py index a6486375af7..467096ccc9a 100644 --- a/test/distributed/_sharded_tensor/test_sharded_tensor.py +++ b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py @@ -8,22 +8,24 @@ import sys import torch import torch.distributed as dist from torch.distributed import rpc -from torch.distributed import _sharded_tensor -from torch.distributed._sharded_tensor import ( +from torch.distributed._shard import sharded_tensor +from torch.distributed._shard import ( + shard_parameter, +) +from torch.distributed._shard.sharded_tensor import ( sharded_op_impl, load_with_process_group, pre_load_state_dict_hook, - shard_parameter, state_dict_hook, + ShardedTensor, ) -from torch.distributed._sharding_spec import ( +from torch.distributed._shard.sharding_spec import ( ChunkShardingSpec, EnumerableShardingSpec, - ShardMetadata + ShardMetadata, ) -from torch.distributed._sharded_tensor.api import ( +from torch.distributed._shard.sharded_tensor.api import ( CreateOp, - ShardedTensor, TensorInitParams, TensorProperties, _create_tensor_from_params, @@ -38,7 +40,7 @@ from torch.testing._internal.common_utils import ( run_tests, sandcastle_skip_if, ) -from torch.testing._internal.distributed._sharded_tensor import ( +from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, with_comms, ) @@ -52,7 +54,7 @@ class MyShardedModel2(torch.nn.Module): def __init__(self, spec=None, group=None): super(MyShardedModel2, self).__init__() if spec is not None: - self.sharded_tensor2 = _sharded_tensor.empty(spec, 10, 20, process_group=group, init_rrefs=True) + self.sharded_tensor2 = 
sharded_tensor.empty(spec, 10, 20, process_group=group, init_rrefs=True) else: self.sharded_tensor2 = None self.random_tensor2 = torch.nn.Parameter(torch.rand(2, 2)) @@ -62,7 +64,7 @@ class MyShardedModel1(torch.nn.Module): def __init__(self, spec=None, group=None): super(MyShardedModel1, self).__init__() if spec is not None: - self.sharded_tensor1 = _sharded_tensor.empty(spec, 10, 20, process_group=group, init_rrefs=True) + self.sharded_tensor1 = sharded_tensor.empty(spec, 10, 20, process_group=group, init_rrefs=True) else: self.sharded_tensor1 = None self.random_tensor1 = torch.nn.Parameter(torch.rand(2, 2)) @@ -106,10 +108,10 @@ class TestShardedTensorMetadata(TestCase): for tensor_properties_input in itertools.product(dtypes, layouts, requires_grads, memory_formats, pin_memories): dtype, layout, requires_grad, memory_format, pin_memory = tensor_properties_input - expected_st_metadata = _sharded_tensor.ShardedTensorMetadata( + expected_st_metadata = sharded_tensor.ShardedTensorMetadata( shard_metadatas, (10, 10), - _sharded_tensor.TensorProperties(dtype, layout, requires_grad, memory_format, pin_memory) + TensorProperties(dtype, layout, requires_grad, memory_format, pin_memory) ) pickled_obj = pickle.dumps(expected_st_metadata) @@ -263,7 +265,7 @@ class TestShardParameter(ShardedTensorTestBase): shard_parameter(fc, 'weight', spec) # Verify. 
- self.assertTrue(isinstance(fc.weight, _sharded_tensor.ShardedTensor)) + self.assertTrue(isinstance(fc.weight, ShardedTensor)) local_shards = fc.weight.local_shards() self.assertEqual(1, len(local_shards)) self.assertEqual(torch.Size([3, 12]), local_shards[0].tensor.size()) @@ -345,20 +347,20 @@ class TestShardedTensorChunked(ShardedTensorTestBase): ], ) - sharded_tensor = _sharded_tensor.empty(spec, 10, 20, init_rrefs=True) - sharded_tensor_metadata = sharded_tensor.metadata() - self.assertEqual(torch.Size([10, 20]), sharded_tensor_metadata.size) - self.assertEqual(torch.float, sharded_tensor.dtype) - self.assertEqual(torch.strided, sharded_tensor.layout) - self.assertEqual(False, sharded_tensor.requires_grad) - self.assertTrue(sharded_tensor.is_contiguous()) - self.assertFalse(sharded_tensor.is_pinned()) + st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) + st_metadata = st.metadata() + self.assertEqual(torch.Size([10, 20]), st_metadata.size) + self.assertEqual(torch.float, st.dtype) + self.assertEqual(torch.strided, st.layout) + self.assertEqual(False, st.requires_grad) + self.assertTrue(st.is_contiguous()) + self.assertFalse(st.is_pinned()) - sharded_tensor = _sharded_tensor.empty(spec, 10, 20, requires_grad=True, init_rrefs=True) - self.assertEqual(True, sharded_tensor.requires_grad) + st = sharded_tensor.empty(spec, 10, 20, requires_grad=True, init_rrefs=True) + self.assertEqual(True, st.requires_grad) - sharded_tensor = _sharded_tensor.empty(spec, 10, 20, dtype=torch.double, init_rrefs=True) - self.assertEqual(torch.double, sharded_tensor.dtype) + st = sharded_tensor.empty(spec, 10, 20, dtype=torch.double, init_rrefs=True) + self.assertEqual(torch.double, st.dtype) # Need CPU for pin_memory spec = ChunkShardingSpec( @@ -371,13 +373,13 @@ class TestShardedTensorChunked(ShardedTensorTestBase): ], ) - sharded_tensor = _sharded_tensor.empty(spec, 10, 20, pin_memory=True, init_rrefs=True) - self.assertEqual(True, sharded_tensor.is_pinned()) + st = 
sharded_tensor.empty(spec, 10, 20, pin_memory=True, init_rrefs=True) + self.assertEqual(True, st.is_pinned()) # test read only properties, they're read only as we can't simply change # the global metadata without changing the underlying shard's properties with self.assertRaisesRegex(AttributeError, "can't set attribute"): - sharded_tensor.requires_grad = True + st.requires_grad = True @with_comms @skip_if_lt_x_gpu(4) @@ -394,10 +396,10 @@ class TestShardedTensorChunked(ShardedTensorTestBase): "rank:3/cuda:3", ], ) - sharded_tensor = _sharded_tensor.empty(spec, 10, 20, init_rrefs=True) + st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) # Validate local shard. - local_shards = sharded_tensor.local_shards() + local_shards = st.local_shards() self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) @@ -407,8 +409,8 @@ class TestShardedTensorChunked(ShardedTensorTestBase): self.assertEqual((3, 20), local_shard.size()) # Validate global metadata. - sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(4, len(shards_metadata)) for rank, shard_metadata in enumerate(shards_metadata): @@ -420,7 +422,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): self.assertEqual(f'rank:{rank}/cuda:{rank}', str(shard_metadata.placement)) # Validate remote shards. - remote_shards = sharded_tensor.remote_shards() + remote_shards = st.remote_shards() self.assertEqual(3, len(remote_shards)) for rpc_rank, shards in remote_shards.items(): @@ -439,7 +441,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): @skip_if_lt_x_gpu(4) @requires_nccl() def test_create_sharded_tensor_with_ones(self): - """ Test _sharded_tensor.ones(...) """ + """ Test sharded_tensor.ones(...) 
""" spec = ChunkShardingSpec( dim=0, @@ -451,10 +453,10 @@ class TestShardedTensorChunked(ShardedTensorTestBase): ], ) h, w = 10, 20 - sharded_tensor = _sharded_tensor.ones(spec, h, w) + st = sharded_tensor.ones(spec, h, w) # Validate local shard is initialized with torch.ones - local_shards = sharded_tensor.local_shards() + local_shards = st.local_shards() self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) @@ -467,7 +469,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): @skip_if_lt_x_gpu(4) @requires_nccl() def test_gather_even(self) -> None: - """ Test _sharded_tensor.gather(...) with evenly distributed shards""" + """ Test _sharded_tensor.gather(...) with evenly distributed._shards""" spec = ChunkShardingSpec( dim=0, @@ -479,7 +481,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): ], ) h, w = 10, 20 - sharded_tensor = _sharded_tensor.ones(spec, h, w) + st = sharded_tensor.ones(spec, h, w) full_tensor = None dst = 1 @@ -489,7 +491,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): w, device=torch.device(f"cuda:{dst}"), ) - sharded_tensor.gather(dst, full_tensor) + st.gather(dst, full_tensor) if self.rank == dst: self.assertEqual(full_tensor, torch.ones(h, w)) @@ -500,7 +502,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): @skip_if_lt_x_gpu(4) @requires_nccl() def test_gather_uneven(self) -> None: - """ Test _sharded_tensor.gather(...) with unevenly distributed shards""" + """ Test _sharded_tensor.gather(...) 
with unevenly distributed._shards""" spec = ChunkShardingSpec( dim=0, @@ -513,7 +515,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): ], ) h, w = 10, 20 - sharded_tensor = _sharded_tensor.ones(spec, h, w) + st = sharded_tensor.ones(spec, h, w) full_tensor = None dst = 1 @@ -523,7 +525,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): w, device=torch.device(f"cuda:{dst}"), ) - sharded_tensor.gather(dst, full_tensor) + st.gather(dst, full_tensor) if self.rank == dst: self.assertEqual(full_tensor, torch.ones(h, w)) @@ -534,7 +536,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): @skip_if_lt_x_gpu(4) @requires_nccl() def test_create_sharded_tensor_with_zeros(self): - """ Test _sharded_tensor.zeros(...) """ + """ Test sharded_tensor.zeros(...) """ spec = ChunkShardingSpec( dim=0, @@ -546,10 +548,10 @@ class TestShardedTensorChunked(ShardedTensorTestBase): ], ) h, w = 10, 20 - sharded_tensor = _sharded_tensor.zeros(spec, h, w) + st = sharded_tensor.zeros(spec, h, w) # Validate local shard is initialized with torch.zeros - local_shards = sharded_tensor.local_shards() + local_shards = st.local_shards() self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) @@ -563,7 +565,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): @skip_if_lt_x_gpu(4) @requires_nccl() def test_create_sharded_tensor_with_rand(self): - """ Test _sharded_tensor.rand(...) """ + """ Test sharded_tensor.rand(...) 
""" spec = ChunkShardingSpec( dim=0, @@ -584,10 +586,10 @@ class TestShardedTensorChunked(ShardedTensorTestBase): expected = torch.rand(expected_h, w, device=expected_device, dtype=dtype) # reset seed to ensure the same random numbers are generated torch.manual_seed(seed) - sharded_tensor = _sharded_tensor.rand(spec, h, w, dtype=dtype) + st = sharded_tensor.rand(spec, h, w, dtype=dtype) # Validate local shard is initialized with torch.rand - local_shards = sharded_tensor.local_shards() + local_shards = st.local_shards() self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor self.assertEqual(expected_device, local_shard.device) @@ -599,7 +601,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): @skip_if_lt_x_gpu(4) @requires_nccl() def test_create_sharded_tensor_with_full(self): - """ Test _sharded_tensor.full(...) """ + """ Test sharded_tensor.full(...) """ spec = ChunkShardingSpec( dim=0, @@ -612,10 +614,10 @@ class TestShardedTensorChunked(ShardedTensorTestBase): ) h, w = 10, 20 fill_value = 1234 - sharded_tensor = _sharded_tensor.full(spec, size=(h, w), fill_value=fill_value, dtype=torch.int32) + st = sharded_tensor.full(spec, size=(h, w), fill_value=fill_value, dtype=torch.int32) # Validate local shard is initialized with torch.full - local_shards = sharded_tensor.local_shards() + local_shards = st.local_shards() self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) @@ -638,10 +640,10 @@ class TestShardedTensorChunked(ShardedTensorTestBase): "rank:3/cuda:3", ], ) - sharded_tensor = _sharded_tensor.empty(spec, 10, 20, init_rrefs=True) + st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) # Validate local shard. 
- local_shards = sharded_tensor.local_shards() + local_shards = st.local_shards() if self.rank >= 2: self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor @@ -651,8 +653,8 @@ class TestShardedTensorChunked(ShardedTensorTestBase): self.assertEqual(0, len(local_shards)) # Validate global metadata. - sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(2, len(shards_metadata)) for shard_rank, shard_metadata in enumerate(shards_metadata): @@ -661,7 +663,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): self.assertEqual(f'rank:{shard_rank + 2}/cuda:{shard_rank + 2}', str(shard_metadata.placement)) # Validate remote shards. - remote_shards = sharded_tensor.remote_shards() + remote_shards = st.remote_shards() if self.rank >= 2: self.assertEqual(1, len(remote_shards)) else: @@ -689,10 +691,10 @@ class TestShardedTensorChunked(ShardedTensorTestBase): ) pg = dist.new_group(ranks=[1, 2, 3]) - sharded_tensor = _sharded_tensor.empty(spec, 10, 20, process_group=pg, init_rrefs=True) + st = sharded_tensor.empty(spec, 10, 20, process_group=pg, init_rrefs=True) # Validate local shard. - local_shards = sharded_tensor.local_shards() + local_shards = st.local_shards() if self.rank >= 2: self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor @@ -702,8 +704,8 @@ class TestShardedTensorChunked(ShardedTensorTestBase): self.assertEqual(0, len(local_shards)) # Validate global metadata. 
- sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(2, len(shards_metadata)) for shard_rank, shard_metadata in enumerate(shards_metadata): @@ -712,7 +714,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): self.assertEqual(f'rank:{shard_rank + 1}/cuda:{shard_rank + 2}', str(shard_metadata.placement)) # Validate remote shards. - remote_shards = sharded_tensor.remote_shards() + remote_shards = st.remote_shards() if self.rank >= 2: self.assertEqual(1, len(remote_shards)) else: @@ -743,18 +745,18 @@ class TestShardedTensorChunked(ShardedTensorTestBase): "rank:3/cuda:3", ], ) - sharded_tensor = _sharded_tensor.empty(spec, 16, 20, init_rrefs=True) + st = sharded_tensor.empty(spec, 16, 20, init_rrefs=True) # Validate local shards. - local_shards = sharded_tensor.local_shards() + local_shards = st.local_shards() self.assertEqual(2, len(local_shards)) for local_shard in local_shards: self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device) self.assertEqual((2, 20), local_shard.tensor.size()) # Validate global metadata. - sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(8, len(shards_metadata)) for shard_idx, shard_metadata in enumerate(shards_metadata): @@ -763,7 +765,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): self.assertEqual(f'rank:{shard_idx % 4}/cuda:{shard_idx % 4}', str(shard_metadata.placement)) # Validate remote shards. 
- remote_shards = sharded_tensor.remote_shards() + remote_shards = st.remote_shards() self.assertEqual(3, len(remote_shards)) owners = {} for rpc_rank, shards in remote_shards.items(): @@ -790,18 +792,18 @@ class TestShardedTensorChunked(ShardedTensorTestBase): ], ) - sharded_tensor = _sharded_tensor.empty(spec, 10, 32) + st = sharded_tensor.empty(spec, 10, 32) # Validate local shard. - local_shards = sharded_tensor.local_shards() + local_shards = st.local_shards() self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device) self.assertEqual((10, 8), local_shard.size()) # Validate global metadata. - sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(4, len(shards_metadata)) for rank, shard_metadata in enumerate(shards_metadata): @@ -816,47 +818,47 @@ class TestShardedTensorChunked(ShardedTensorTestBase): spec = ChunkShardingSpec(dim='H', placements=["rank:1/cuda:1"]) with self.assertRaisesRegex(ValueError, 'needs to be an integer'): - _sharded_tensor.empty(spec, 10, 20) + sharded_tensor.empty(spec, 10, 20) for dim in [2, 3, 4, -3, -4, -5]: spec = ChunkShardingSpec(dim=dim, placements=["rank:1/cuda:1"]) with self.assertRaisesRegex(ValueError, 'Invalid sharding dim'): - _sharded_tensor.empty(spec, 10, 20) + sharded_tensor.empty(spec, 10, 20) spec = ChunkShardingSpec(dim=0, placements=["rank:5/cuda:1"]) with self.assertRaisesRegex(ValueError, 'Invalid rank'): - _sharded_tensor.empty(spec, 10, 20) + sharded_tensor.empty(spec, 10, 20) spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) - sharded_tensor = _sharded_tensor.empty(spec, 10, 20) + st = sharded_tensor.empty(spec, 10, 20) tensor = torch.empty(10, 20) with self.assertRaisesRegex(RuntimeError, "not supported for ShardedTensor!"): - torch.add(sharded_tensor, tensor) 
+ torch.add(st, tensor) spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) with self.assertRaisesRegex(ValueError, 'Only torch.strided layout is currently supported'): - _sharded_tensor.empty(spec, 10, 20, layout=torch.sparse) + sharded_tensor.empty(spec, 10, 20, layout=torch.sparse) spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) with self.assertRaisesRegex(ValueError, 'Only torch.contiguous_format memory_format is currently supported'): - _sharded_tensor.empty(spec, 10, 20, memory_format=torch.channels_last) + sharded_tensor.empty(spec, 10, 20, memory_format=torch.channels_last) spec = ChunkShardingSpec(dim=0, placements=["worker0/cuda:1"]) with self.assertRaisesRegex(RuntimeError, 'RPC framework needs to be initialized'): - _sharded_tensor.empty(spec, 10, 20) + sharded_tensor.empty(spec, 10, 20) spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"]) with self.assertRaisesRegex(RuntimeError, 'RPC Framework needs to be initialized'): - st = _sharded_tensor.empty(spec, 10, 20, init_rrefs=True) + st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) with self.assertRaisesRegex(RuntimeError, 'ShardedTensor created with init_rrefs=False'): - st = _sharded_tensor.empty(spec, 10, 20) + st = sharded_tensor.empty(spec, 10, 20) st.remote_shards() self.init_rpc() spec = ChunkShardingSpec(dim=0, placements=["workerfoo/cuda:1"]) with self.assertRaisesRegex(ValueError, 'Invalid worker name'): - _sharded_tensor.empty(spec, 10, 20, init_rrefs=True) + sharded_tensor.empty(spec, 10, 20, init_rrefs=True) @skip_if_lt_x_gpu(4) @requires_nccl() @@ -876,7 +878,7 @@ class TestShardedTensorChunked(ShardedTensorTestBase): spec = ChunkShardingSpec(dim=0, placements=["rank:1/cuda:1"]) with self.assertRaisesRegex(ValueError, 'Default ProcessGroup and RPC ranks must be the same'): - _sharded_tensor.empty(spec, 10, 20, init_rrefs=True) + sharded_tensor.empty(spec, 10, 20, init_rrefs=True) @skip_if_lt_x_gpu(4) @requires_nccl() @@ -892,10 +894,10 @@ class 
TestShardedTensorChunked(ShardedTensorTestBase): "rank:3/cuda:3", ], ) - sharded_tensor = _sharded_tensor.empty(spec, 2, 20) + st = sharded_tensor.empty(spec, 2, 20) # Validate local shard. - local_shards = sharded_tensor.local_shards() + local_shards = st.local_shards() if self.rank <= 1: self.assertEqual(1, len(local_shards)) local_shard = local_shards[0].tensor @@ -905,8 +907,8 @@ class TestShardedTensorChunked(ShardedTensorTestBase): self.assertEqual(0, len(local_shards)) # Validate global metadata. - sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(2, len(shards_metadata)) for shard_rank, shard_metadata in enumerate(shards_metadata): @@ -929,38 +931,38 @@ class TestShardedTensorChunked(ShardedTensorTestBase): ) # Test with *args - sharded_tensor = _sharded_tensor.empty(spec, 10, 20, init_rrefs=True) - self.assertEqual(torch.Size([10, 20]), sharded_tensor.size()) + st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) + self.assertEqual(torch.Size([10, 20]), st.size()) # Test with single *args - sharded_tensor = _sharded_tensor.empty(spec, 10, init_rrefs=True) - self.assertEqual(torch.Size([10]), sharded_tensor.size()) + st = sharded_tensor.empty(spec, 10, init_rrefs=True) + self.assertEqual(torch.Size([10]), st.size()) # Test with list - sharded_tensor = _sharded_tensor.empty(spec, [10, 20], init_rrefs=True) - self.assertEqual(torch.Size([10, 20]), sharded_tensor.size()) + st = sharded_tensor.empty(spec, [10, 20], init_rrefs=True) + self.assertEqual(torch.Size([10, 20]), st.size()) # Test with tuple - sharded_tensor = _sharded_tensor.empty(spec, (10, 20), init_rrefs=True) - self.assertEqual(torch.Size([10, 20]), sharded_tensor.size()) + st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True) + self.assertEqual(torch.Size([10, 20]), st.size()) # Test with row size - sharded_tensor = 
_sharded_tensor.empty(spec, (10, 20), init_rrefs=True) - self.assertEqual(sharded_tensor.size(0), 10) + st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True) + self.assertEqual(st.size(0), 10) # Test with col size - sharded_tensor = _sharded_tensor.empty(spec, (10, 20), init_rrefs=True) - self.assertEqual(sharded_tensor.size(1), 20) + st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True) + self.assertEqual(st.size(1), 20) # Test with invalid input - sharded_tensor = _sharded_tensor.empty(spec, (10, 20), init_rrefs=True) + st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True) with self.assertRaisesRegex(ValueError, 'must be within the range of tensor dimensions \\[0, 2\\)'): - sharded_tensor.size(-1) + st.size(-1) with self.assertRaisesRegex(ValueError, 'must be within the range of tensor dimensions \\[0, 2\\)'): - sharded_tensor.size(2) + st.size(2) with self.assertRaises(TypeError): - sharded_tensor = _sharded_tensor.empty(spec, 'foo') + st = sharded_tensor.empty(spec, 'foo') @with_comms @skip_if_lt_x_gpu(4) @@ -1119,11 +1121,11 @@ class TestShardedTensorChunked(ShardedTensorTestBase): "rank:3/cuda:3", ], ) - st1 = _sharded_tensor.empty(spec, 10, 20, init_rrefs=True) - st2 = _sharded_tensor.empty(spec, 10, 20) + st1 = sharded_tensor.empty(spec, 10, 20, init_rrefs=True) + st2 = sharded_tensor.empty(spec, 10, 20) create_tensors() - self.assertEqual(0, len(_sharded_tensor.api._sharded_tensor_map)) + self.assertEqual(0, len(sharded_tensor.api._sharded_tensor_map)) class TestShardedTensorEnumerable(ShardedTensorTestBase): @@ -1155,20 +1157,20 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): ) ]) - sharded_tensor = _sharded_tensor.empty(spec, 10, 10, init_rrefs=True) - sharded_tensor_metadata = sharded_tensor.metadata() - self.assertEqual(torch.Size([10, 10]), sharded_tensor_metadata.size) - self.assertEqual(torch.float, sharded_tensor.dtype) - self.assertEqual(torch.strided, sharded_tensor.layout) - self.assertEqual(False, 
sharded_tensor.requires_grad) - self.assertTrue(sharded_tensor.is_contiguous()) - self.assertFalse(sharded_tensor.is_pinned()) + st = sharded_tensor.empty(spec, 10, 10, init_rrefs=True) + st_metadata = st.metadata() + self.assertEqual(torch.Size([10, 10]), st_metadata.size) + self.assertEqual(torch.float, st.dtype) + self.assertEqual(torch.strided, st.layout) + self.assertEqual(False, st.requires_grad) + self.assertTrue(st.is_contiguous()) + self.assertFalse(st.is_pinned()) - sharded_tensor = _sharded_tensor.empty(spec, 10, 10, requires_grad=True, init_rrefs=True) - self.assertEqual(True, sharded_tensor.requires_grad) + st = sharded_tensor.empty(spec, 10, 10, requires_grad=True, init_rrefs=True) + self.assertEqual(True, st.requires_grad) - sharded_tensor = _sharded_tensor.empty(spec, 10, 10, dtype=torch.double, init_rrefs=True) - self.assertEqual(torch.double, sharded_tensor.dtype) + st = sharded_tensor.empty(spec, 10, 10, dtype=torch.double, init_rrefs=True) + self.assertEqual(torch.double, st.dtype) # Need CPU for pin_memory spec = EnumerableShardingSpec([ @@ -1194,8 +1196,8 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): ) ]) - sharded_tensor = _sharded_tensor.empty(spec, 10, 10, pin_memory=True, init_rrefs=True) - self.assertTrue(sharded_tensor.is_pinned()) + st = sharded_tensor.empty(spec, 10, 10, pin_memory=True, init_rrefs=True) + self.assertTrue(st.is_pinned()) @with_comms @skip_if_lt_x_gpu(4) @@ -1225,12 +1227,12 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): ) ]) - sharded_tensor = _sharded_tensor.empty(spec, 10, 10, init_rrefs=True) - self.assertEqual((10, 10), sharded_tensor.size()) - self.assertEqual(1, len(sharded_tensor.local_shards())) + st = sharded_tensor.empty(spec, 10, 10, init_rrefs=True) + self.assertEqual((10, 10), st.size()) + self.assertEqual(1, len(st.local_shards())) # Verify local shard. 
- local_shard = sharded_tensor.local_shards()[0] + local_shard = st.local_shards()[0] self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -1240,8 +1242,8 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', str(local_shard.metadata.placement)) # Verify global metadata. - sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(4, len(shards_metadata)) for rank, shard_metadata in enumerate(shards_metadata): self.assertEqual((rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets) @@ -1249,7 +1251,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): self.assertEqual(f'rank:{rank}/cuda:{rank}', str(shard_metadata.placement)) # Validate remote shards. - remote_shards = sharded_tensor.remote_shards() + remote_shards = st.remote_shards() self.assertEqual(3, len(remote_shards)) for rpc_rank, shards in remote_shards.items(): @@ -1263,7 +1265,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): @skip_if_lt_x_gpu(4) @requires_nccl() def test_create_sharded_tensor_with_ones(self): - """ Test _sharded_tensor.ones(...) """ + """ Test sharded_tensor.ones(...) 
""" spec = EnumerableShardingSpec([ ShardMetadata( @@ -1288,12 +1290,12 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): ) ]) - sharded_tensor = _sharded_tensor.ones(spec, 10, 10, init_rrefs=True) - self.assertEqual((10, 10), sharded_tensor.size()) - self.assertEqual(1, len(sharded_tensor.local_shards())) + st = sharded_tensor.ones(spec, 10, 10, init_rrefs=True) + self.assertEqual((10, 10), st.size()) + self.assertEqual(1, len(st.local_shards())) # Verify local shard is initialized with torch.ones - local_shard = sharded_tensor.local_shards()[0] + local_shard = st.local_shards()[0] self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) self.assertEqual(local_shard.tensor, torch.ones(5, 5)) @@ -1302,7 +1304,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): @skip_if_lt_x_gpu(4) @requires_nccl() def test_gather_even(self) -> None: - """ Test _sharded_tensor.gather(...) with evenly distributed shards""" + """ Test _sharded_tensor.gather(...) with evenly distributed._shards""" spec = EnumerableShardingSpec([ ShardMetadata( @@ -1328,7 +1330,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): ]) h, w = 10, 10 - sharded_tensor = _sharded_tensor.ones(spec, h, w, init_rrefs=True) + st = sharded_tensor.ones(spec, h, w, init_rrefs=True) full_tensor = None dst = 0 @@ -1338,7 +1340,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): w, device=torch.device(f"cuda:{dst}") ) - sharded_tensor.gather(dst, full_tensor) + st.gather(dst, full_tensor) if self.rank == dst: self.assertEqual(full_tensor, torch.ones(h, w)) @@ -1349,7 +1351,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): @skip_if_lt_x_gpu(4) @requires_nccl() def test_gather_uneven(self) -> None: - """ Test _sharded_tensor.gather(...) with unevenly distributed shards""" + """ Test _sharded_tensor.gather(...) 
with unevenly distributed shards""" spec = EnumerableShardingSpec([ ShardMetadata( @@ -1375,7 +1377,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): ]) h, w = 10, 10 - sharded_tensor = _sharded_tensor.ones(spec, h, w, init_rrefs=True) + st = sharded_tensor.ones(spec, h, w, init_rrefs=True) full_tensor = None dst = 0 @@ -1385,7 +1387,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): w, device=torch.device(f"cuda:{dst}") ) - sharded_tensor.gather(dst, full_tensor) + st.gather(dst, full_tensor) if self.rank == dst: self.assertEqual(full_tensor, torch.ones(h, w)) @@ -1420,9 +1422,9 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): ), ]) - sharded_tensor = _sharded_tensor.empty(spec, 6, 6) - self.assertEqual((6, 6), sharded_tensor.size()) - self.assertEqual(1, len(sharded_tensor.local_shards())) + st = sharded_tensor.empty(spec, 6, 6) + self.assertEqual((6, 6), st.size()) + self.assertEqual(1, len(st.local_shards())) def verify_size(rank, tensor_dims): if rank == 0: @@ -1445,7 +1447,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): self.assertEqual((4, 4), offsets) # Verify local shard. - local_shard = sharded_tensor.local_shards()[0] + local_shard = st.local_shards()[0] self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device) verify_size(self.rank, local_shard.tensor.size()) @@ -1455,8 +1457,8 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): self.assertEqual((5, 5), local_shard.metadata.shard_sizes) self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', str(local_shard.metadata.placement)) # Verify global metadata. 
- sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(4, len(shards_metadata)) for rank, shard_metadata in enumerate(shards_metadata): verify_offsets(rank, shard_metadata.shard_offsets) @@ -1480,16 +1482,16 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): ), ]) - sharded_tensor = _sharded_tensor.empty(spec, 10, 5, init_rrefs=True) - self.assertEqual((10, 5), sharded_tensor.size()) + st = sharded_tensor.empty(spec, 10, 5, init_rrefs=True) + self.assertEqual((10, 5), st.size()) if self.rank <= 1: - self.assertEqual(1, len(sharded_tensor.local_shards())) + self.assertEqual(1, len(st.local_shards())) else: - self.assertEqual(0, len(sharded_tensor.local_shards())) + self.assertEqual(0, len(st.local_shards())) if self.rank <= 1: # Verify local shard. - local_shard = sharded_tensor.local_shards()[0] + local_shard = st.local_shards()[0] self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -1499,8 +1501,8 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', str(local_shard.metadata.placement)) # Verify global metadata. - sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(2, len(shards_metadata)) for rank, shard_metadata in enumerate(shards_metadata): self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets) @@ -1508,7 +1510,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): self.assertEqual(f'rank:{rank}/cuda:{rank}', str(shard_metadata.placement)) # Validate remote shards. 
- remote_shards = sharded_tensor.remote_shards() + remote_shards = st.remote_shards() if self.rank <= 1: self.assertEqual(1, len(remote_shards)) else: @@ -1541,11 +1543,11 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): pg = dist.new_group(ranks=[1, 2, 3]) - sharded_tensor = _sharded_tensor.empty(spec, 10, 5, process_group=pg, init_rrefs=True) - self.assertEqual((10, 5), sharded_tensor.size()) + st = sharded_tensor.empty(spec, 10, 5, process_group=pg, init_rrefs=True) + self.assertEqual((10, 5), st.size()) if self.rank == 1 or self.rank == 3: # Verify local shard. - local_shard = sharded_tensor.local_shards()[0] + local_shard = st.local_shards()[0] self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -1555,8 +1557,8 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): self.assertEqual(f'rank:{self.rank - 1}/cuda:{self.rank}', str(local_shard.metadata.placement)) # Verify global metadata. - sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(2, len(shards_metadata)) for rank, shard_metadata in enumerate(shards_metadata): self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets) @@ -1564,7 +1566,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): self.assertEqual(f'rank:{rank * 2}/cuda:{rank * 2 + 1}', str(shard_metadata.placement)) # Validate remote shards. 
- remote_shards = sharded_tensor.remote_shards() + remote_shards = st.remote_shards() if self.rank == 1 or self.rank == 3: self.assertEqual(1, len(remote_shards)) else: @@ -1606,14 +1608,14 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): ) ]) - sharded_tensor = _sharded_tensor.empty(spec, 10, 10, init_rrefs=True) - self.assertEqual((10, 10), sharded_tensor.size()) + st = sharded_tensor.empty(spec, 10, 10, init_rrefs=True) + self.assertEqual((10, 10), st.size()) if self.rank <= 1: - self.assertEqual(2, len(sharded_tensor.local_shards())) + self.assertEqual(2, len(st.local_shards())) # Verify local shards. - for idx, local_shard in enumerate(sharded_tensor.local_shards()): + for idx, local_shard in enumerate(st.local_shards()): self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -1622,11 +1624,11 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): self.assertEqual((5, 5), local_shard.metadata.shard_sizes) self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', str(local_shard.metadata.placement)) else: - self.assertEqual(0, len(sharded_tensor.local_shards())) + self.assertEqual(0, len(st.local_shards())) # Verify global metadata. - sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(4, len(shards_metadata)) for shard_rank, shard_metadata in enumerate(shards_metadata): self.assertEqual((shard_rank // 2 * 5, (shard_rank % 2) * 5), shard_metadata.shard_offsets) @@ -1634,7 +1636,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): self.assertEqual(f'rank:{shard_rank % 2}/cuda:{shard_rank % 2}', str(shard_metadata.placement)) # Validate remote shards. 
- remote_shards = sharded_tensor.remote_shards() + remote_shards = st.remote_shards() if self.rank <= 1: self.assertEqual(1, len(remote_shards)) else: @@ -1675,12 +1677,12 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): ) ]) - sharded_tensor = _sharded_tensor.empty(spec, 10, 10, init_rrefs=True) - self.assertEqual((10, 10), sharded_tensor.size()) - self.assertEqual(1, len(sharded_tensor.local_shards())) + st = sharded_tensor.empty(spec, 10, 10, init_rrefs=True) + self.assertEqual((10, 10), st.size()) + self.assertEqual(1, len(st.local_shards())) # Verify local shard. - local_shard = sharded_tensor.local_shards()[0] + local_shard = st.local_shards()[0] self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -1690,8 +1692,8 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): self.assertEqual(f'worker{self.rank}/cuda:{self.rank}', str(local_shard.metadata.placement)) # Verify global metadata. - sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(4, len(shards_metadata)) for rank, shard_metadata in enumerate(shards_metadata): self.assertEqual((rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets) @@ -1699,7 +1701,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase): self.assertEqual(f'worker{rank}/cuda:{rank}', str(shard_metadata.placement)) # Validate remote shards. 
- remote_shards = sharded_tensor.remote_shards() + remote_shards = st.remote_shards() self.assertEqual(3, len(remote_shards)) for rpc_rank, shards in remote_shards.items(): @@ -1724,8 +1726,8 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): ) local_tensor = torch.randn(5, 5, device=f"cuda:{self.rank}") - local_shard = _sharded_tensor.Shard(local_tensor, local_shard_metadata) - local_shard_from_offsets = _sharded_tensor.Shard.from_tensor_and_offsets( + local_shard = sharded_tensor.Shard(local_tensor, local_shard_metadata) + local_shard_from_offsets = sharded_tensor.Shard.from_tensor_and_offsets( local_tensor, shard_offsets=shard_offsets, rank=self.rank @@ -1738,7 +1740,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): placement=f"rank:{self.rank}/cuda:{self.rank}" ) with self.assertRaisesRegex(ValueError, 'Shard tensor size does not match'): - local_shard_from_wrong_meta = _sharded_tensor.Shard( + local_shard_from_wrong_meta = sharded_tensor.Shard( local_tensor, metadata=wrong_local_shard_metadata, ) @@ -1753,14 +1755,14 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): placement=f"rank:{self.rank}/cuda:{self.rank}" ) - local_shards = [_sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata)] + local_shards = [sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata)] - sharded_tensor = _sharded_tensor.init_from_local_shards(local_shards, [10, 10], init_rrefs=True) - self.assertEqual((10, 10), sharded_tensor.size()) - self.assertEqual(1, len(sharded_tensor.local_shards())) + st = sharded_tensor.init_from_local_shards(local_shards, [10, 10], init_rrefs=True) + self.assertEqual((10, 10), st.size()) + self.assertEqual(1, len(st.local_shards())) # Verify local shard. 
- local_shard = sharded_tensor.local_shards()[0] + local_shard = st.local_shards()[0] self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -1770,7 +1772,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', str(local_shard.metadata.placement)) # Verify global metadata. - shards_metadata = sharded_tensor.metadata().shards_metadata + shards_metadata = st.metadata().shards_metadata self.assertEqual(4, len(shards_metadata)) for rank, shard_metadata in enumerate(shards_metadata): self.assertEqual((rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets) @@ -1778,7 +1780,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): self.assertEqual(f'rank:{rank}/cuda:{rank}', str(shard_metadata.placement)) # Validate remote shards. - remote_shards = sharded_tensor.remote_shards() + remote_shards = st.remote_shards() self.assertEqual(3, len(remote_shards)) for rpc_rank, shards in remote_shards.items(): @@ -1810,7 +1812,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): placement=f"rank:{r}/cuda:{r}" )) - local_shards = [_sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata)] + local_shards = [sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata)] tensor_properties = TensorProperties( dtype=torch.get_default_dtype(), @@ -1820,22 +1822,22 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): pin_memory=False, ) - sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata( + sharded_tensor_metadata = sharded_tensor.ShardedTensorMetadata( shards_metadata=shards_metadata, size=torch.Size([10, 10]), tensor_properties=tensor_properties, ) - sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata( + st = ShardedTensor._init_from_local_shards_and_global_metadata( local_shards, 
sharded_tensor_metadata, init_rrefs=True, ) - self.assertEqual((10, 10), sharded_tensor.size()) - self.assertEqual(1, len(sharded_tensor.local_shards())) + self.assertEqual((10, 10), st.size()) + self.assertEqual(1, len(st.local_shards())) # Verify local shard. - local_shard = sharded_tensor.local_shards()[0] + local_shard = st.local_shards()[0] self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -1845,7 +1847,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', str(local_shard.metadata.placement)) # Verify global metadata. - shards_metadata = sharded_tensor.metadata().shards_metadata + shards_metadata = st.metadata().shards_metadata self.assertEqual(4, len(shards_metadata)) for rank, shard_metadata in enumerate(shards_metadata): self.assertEqual((rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets) @@ -1853,7 +1855,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): self.assertEqual(f'rank:{rank}/cuda:{rank}', str(shard_metadata.placement)) # Validate remote shards. - remote_shards = sharded_tensor.remote_shards() + remote_shards = st.remote_shards() self.assertEqual(3, len(remote_shards)) for rpc_rank, shards in remote_shards.items(): @@ -1875,12 +1877,12 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): shard_sizes=[5, 5], placement=f"rank:{self.rank - 1}/cuda:{self.rank}" ) - local_shards = [_sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata)] + local_shards = [sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata)] - sharded_tensor = _sharded_tensor.init_from_local_shards(local_shards, [15, 5], process_group=new_pg) + st = sharded_tensor.init_from_local_shards(local_shards, [15, 5], process_group=new_pg) # Verify local shard. 
- local_shard = sharded_tensor.local_shards()[0] + local_shard = st.local_shards()[0] self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device) self.assertEqual((5, 5), local_shard.tensor.size()) @@ -1890,8 +1892,8 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): self.assertEqual(f'rank:{self.rank - 1}/cuda:{self.rank}', str(local_shard.metadata.placement)) # Verify global metadata. - sharded_tensor_metadata = sharded_tensor.metadata() - shards_metadata = sharded_tensor_metadata.shards_metadata + st_metadata = st.metadata() + shards_metadata = st_metadata.shards_metadata self.assertEqual(3, len(shards_metadata)) for rank, shard_metadata in enumerate(shards_metadata): self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets) @@ -1915,27 +1917,27 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): empty_local_shards = [] with self.assertRaisesRegex(ValueError, 'have no local shards on all ranks'): - sharded_tensor = _sharded_tensor.init_from_local_shards(empty_local_shards, [10, 10], init_rrefs=True) + st = sharded_tensor.init_from_local_shards(empty_local_shards, [10, 10], init_rrefs=True) wrong_layout_shards = [ - _sharded_tensor.Shard(sparse_tensor, local_shard_metadata) + sharded_tensor.Shard(sparse_tensor, local_shard_metadata) ] with self.assertRaisesRegex(ValueError, 'Only torch.strided layout is currently supported'): - sharded_tensor = _sharded_tensor.init_from_local_shards( + st = sharded_tensor.init_from_local_shards( wrong_layout_shards, [10, 10], init_rrefs=True) wrong_memory_format_shards = [ - _sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}").t(), local_shard_metadata) + sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}").t(), local_shard_metadata) ] with self.assertRaisesRegex(ValueError, 'Only torch.contiguous_format memory_format is currently supported'): - sharded_tensor = _sharded_tensor.init_from_local_shards( + st = sharded_tensor.init_from_local_shards( 
wrong_memory_format_shards, [10, 10], init_rrefs=True) with self.assertRaisesRegex(ValueError, 'Shard tensor size does not match'): - wrong_size_shards = [_sharded_tensor.Shard(torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata)] + wrong_size_shards = [sharded_tensor.Shard(torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata)] with self.assertRaisesRegex(ValueError, "Local shard tensor device does not match"): - wrong_device_shards = [_sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)] + wrong_device_shards = [sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)] @with_comms @skip_if_lt_x_gpu(4) @@ -1948,27 +1950,27 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): ) tensor_overall_size = [10, 10] if self.rank == 0 else [10, 5] wrong_dtype_shards = [ - _sharded_tensor.Shard(torch.ones(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata) + sharded_tensor.Shard(torch.ones(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata) ] with self.assertRaisesRegex(ValueError, "ShardedTensor global_size property does not match from different ranks!"): - sharded_tensor = _sharded_tensor.init_from_local_shards(wrong_dtype_shards, tensor_overall_size, init_rrefs=True) + st = sharded_tensor.init_from_local_shards(wrong_dtype_shards, tensor_overall_size, init_rrefs=True) tensor_dtype = torch.int if self.rank == 0 else torch.float32 wrong_dtype_shards = [ - _sharded_tensor.Shard(torch.ones(5, 5, device=f"cuda:{self.rank}", dtype=tensor_dtype), local_shard_metadata) + sharded_tensor.Shard(torch.ones(5, 5, device=f"cuda:{self.rank}", dtype=tensor_dtype), local_shard_metadata) ] with self.assertRaisesRegex(ValueError, "ShardedTensor dtype property does not match from different ranks!"): - sharded_tensor = _sharded_tensor.init_from_local_shards(wrong_dtype_shards, [10, 10], init_rrefs=True) + st = sharded_tensor.init_from_local_shards(wrong_dtype_shards, [10, 10], init_rrefs=True) tensor_requires_grad = True if 
self.rank == 0 else False wrong_requires_grad_shards = [ - _sharded_tensor.Shard( + sharded_tensor.Shard( torch.randn(5, 5, device=f"cuda:{self.rank}", requires_grad=tensor_requires_grad), local_shard_metadata ) ] with self.assertRaisesRegex(ValueError, 'ShardedTensor requires_grad property does not match from different ranks!'): - sharded_tensor = _sharded_tensor.init_from_local_shards( + st = sharded_tensor.init_from_local_shards( wrong_requires_grad_shards, [10, 10], init_rrefs=True) local_shard_metadata = ShardMetadata( @@ -1987,19 +1989,19 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): placement=f"rank:{self.rank}/cpu" ) wrong_pin_memory_local_shards = [ - _sharded_tensor.Shard(torch.randn(5, 5, pin_memory=True), local_shard_metadata), - _sharded_tensor.Shard(torch.randn(5, 5, pin_memory=False), local_shard_metadata) + sharded_tensor.Shard(torch.randn(5, 5, pin_memory=True), local_shard_metadata), + sharded_tensor.Shard(torch.randn(5, 5, pin_memory=False), local_shard_metadata) ] with self.assertRaisesRegex(ValueError, "Local shards' tensor pin_memory property need to be the same"): - sharded_tensor = _sharded_tensor.init_from_local_shards( + st = sharded_tensor.init_from_local_shards( wrong_pin_memory_local_shards, [10, 10], init_rrefs=True) tensor_pin_memory = True if self.rank == 0 else False wrong_pin_memory_shards_cross_ranks = [ - _sharded_tensor.Shard(torch.randn(5, 5, pin_memory=tensor_pin_memory), local_shard_metadata) + sharded_tensor.Shard(torch.randn(5, 5, pin_memory=tensor_pin_memory), local_shard_metadata) ] with self.assertRaisesRegex(ValueError, 'ShardedTensor pin_memory property does not match from different ranks!'): - sharded_tensor = _sharded_tensor.init_from_local_shards( + st = sharded_tensor.init_from_local_shards( wrong_pin_memory_shards_cross_ranks, [10, 10], init_rrefs=True) @@ -2014,10 +2016,10 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): placement=f"rank:{self.rank}/cuda:{self.rank}" ) - 
local_shards = [_sharded_tensor.Shard(torch.randn(local_shard_size, device=f"cuda:{self.rank}"), local_shard_metadata)] + local_shards = [sharded_tensor.Shard(torch.randn(local_shard_size, device=f"cuda:{self.rank}"), local_shard_metadata)] with self.assertRaisesRegex(ValueError, "overlap"): - sharded_tensor = _sharded_tensor.init_from_local_shards(local_shards, [10, 10], init_rrefs=True) + sharded_tensor.init_from_local_shards(local_shards, [10, 10], init_rrefs=True) @with_comms @@ -2031,10 +2033,10 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): placement=f"rank:{self.rank}/cuda:{self.rank}" ) - local_shards = [_sharded_tensor.Shard(torch.randn(local_shard_size, device=f"cuda:{self.rank}"), local_shard_metadata)] + local_shards = [sharded_tensor.Shard(torch.randn(local_shard_size, device=f"cuda:{self.rank}"), local_shard_metadata)] with self.assertRaisesRegex(ValueError, "does not match tensor volume"): - sharded_tensor = _sharded_tensor.init_from_local_shards(local_shards, [10, 10], init_rrefs=True) + sharded_tensor.init_from_local_shards(local_shards, [10, 10], init_rrefs=True) @with_comms @skip_if_lt_x_gpu(4) @@ -2065,7 +2067,7 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): pin_memory=False, ) - sharded_tensor_metadata = _sharded_tensor.ShardedTensorMetadata( + sharded_tensor_metadata = sharded_tensor.ShardedTensorMetadata( shards_metadata=shards_metadata, size=torch.Size([10, 10]), tensor_properties=tensor_properties, @@ -2073,32 +2075,32 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): empty_local_shards = [] with self.assertRaisesRegex(RuntimeError, 'does not match number of local shards metadata'): - sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata( + ShardedTensor._init_from_local_shards_and_global_metadata( empty_local_shards, sharded_tensor_metadata ) wrong_num_shards = [ - _sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata), - 
_sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata) + sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata), + sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}"), local_shard_metadata) ] with self.assertRaisesRegex(RuntimeError, 'does not match number of local shards metadata'): - sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata( + ShardedTensor._init_from_local_shards_and_global_metadata( wrong_num_shards, sharded_tensor_metadata ) with self.assertRaisesRegex(ValueError, 'Shard tensor size does not match with metadata.shard_lengths'): - wrong_size_shards = [_sharded_tensor.Shard(torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata)] + wrong_size_shards = [sharded_tensor.Shard(torch.randn(2, 3, device=f"cuda:{self.rank}"), local_shard_metadata)] with self.assertRaisesRegex(ValueError, "Local shard tensor device does not match with local Shard's placement"): - wrong_device_shards = [_sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)] + wrong_device_shards = [sharded_tensor.Shard(torch.randn(5, 5), local_shard_metadata)] wrong_dtype_shards = [ - _sharded_tensor.Shard(torch.ones(5, 5, device=f"cuda:{self.rank}", dtype=torch.int), local_shard_metadata) + sharded_tensor.Shard(torch.ones(5, 5, device=f"cuda:{self.rank}", dtype=torch.int), local_shard_metadata) ] with self.assertRaisesRegex(ValueError, "Local shards' tensor dtype property is incompatible with"): - sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata( + ShardedTensor._init_from_local_shards_and_global_metadata( wrong_dtype_shards, sharded_tensor_metadata ) @@ -2108,38 +2110,38 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase): sparse_tensor = torch.sparse_coo_tensor(indices, values, (5, 5), device=f"cuda:{self.rank}") wrong_layout_shards = [ - _sharded_tensor.Shard(sparse_tensor, local_shard_metadata) + 
sharded_tensor.Shard(sparse_tensor, local_shard_metadata) ] with self.assertRaisesRegex(ValueError, "Local shards' tensor layout property is incompatible with"): - sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata( + ShardedTensor._init_from_local_shards_and_global_metadata( wrong_layout_shards, sharded_tensor_metadata ) wrong_requires_grad_shards = [ - _sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}", requires_grad=True), local_shard_metadata) + sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}", requires_grad=True), local_shard_metadata) ] with self.assertRaisesRegex(ValueError, "Local shards' tensor requires_grad property is incompatible with"): - sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata( + ShardedTensor._init_from_local_shards_and_global_metadata( wrong_requires_grad_shards, sharded_tensor_metadata ) wrong_memory_format_shards = [ - _sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}").t(), local_shard_metadata) + sharded_tensor.Shard(torch.randn(5, 5, device=f"cuda:{self.rank}").t(), local_shard_metadata) ] with self.assertRaisesRegex(ValueError, 'Only torch.contiguous_format memory_format is currently supported'): - sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata( + ShardedTensor._init_from_local_shards_and_global_metadata( wrong_memory_format_shards, sharded_tensor_metadata ) # pin_memory can only be on CPU local_shard_metadata.placement = _remote_device(f"rank:{self.rank}/cpu") wrong_pin_memory_shards = [ - _sharded_tensor.Shard(torch.randn(5, 5, pin_memory=True), local_shard_metadata) + sharded_tensor.Shard(torch.randn(5, 5, pin_memory=True), local_shard_metadata) ] with self.assertRaisesRegex(ValueError, "Local shards' tensor pin_memory property is incompatible with"): - sharded_tensor = ShardedTensor._init_from_local_shards_and_global_metadata( + ShardedTensor._init_from_local_shards_and_global_metadata( 
wrong_pin_memory_shards, sharded_tensor_metadata ) @@ -2165,7 +2167,7 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase): ], ) - st = _sharded_tensor.rand(spec, 10, 10) + st = sharded_tensor.rand(spec, 10, 10) res = torch.asin(st) self.assertEqual(res, torch.asin(st.local_shards()[0].tensor)) diff --git a/test/distributed/_sharding_spec/test_sharding_spec.py b/test/distributed/_shard/sharding_spec/test_sharding_spec.py similarity index 98% rename from test/distributed/_sharding_spec/test_sharding_spec.py rename to test/distributed/_shard/sharding_spec/test_sharding_spec.py index 66dc0468e5a..079f91dee37 100644 --- a/test/distributed/_sharding_spec/test_sharding_spec.py +++ b/test/distributed/_shard/sharding_spec/test_sharding_spec.py @@ -2,13 +2,13 @@ import torch from torch.testing._internal.common_utils import TestCase -from torch.distributed._sharding_spec import ( +from torch.distributed._shard.sharding_spec import ( ChunkShardingSpec, DevicePlacementSpec, EnumerableShardingSpec, ShardMetadata, ) -from torch.distributed._sharding_spec._internals import ( +from torch.distributed._shard.sharding_spec._internals import ( check_tensor, get_split_size, get_chunked_dim_size, diff --git a/test/run_test.py b/test/run_test.py index f594ad234b8..9079bbc7a0d 100644 --- a/test/run_test.py +++ b/test/run_test.py @@ -202,13 +202,14 @@ WINDOWS_BLOCKLIST = [ "distributed/pipeline/sync/test_worker", "distributed/elastic/agent/server/test/api_test", "distributed/elastic/multiprocessing/api_test", - "distributed/_sharded_tensor/test_sharded_tensor", - "distributed/_sharded_tensor/ops/test_embedding", - "distributed/_sharded_tensor/ops/test_embedding_bag", - "distributed/_sharded_tensor/ops/test_binary_cmp", - "distributed/_sharded_tensor/ops/test_init", - "distributed/_sharded_tensor/ops/test_linear", - "distributed/_sharded_optim/test_sharded_optim", + "distributed/_shard/sharding_spec/test_sharding_spec", + "distributed/_shard/sharded_tensor/test_sharded_tensor", + 
"distributed/_shard/sharded_tensor/ops/test_embedding", + "distributed/_shard/sharded_tensor/ops/test_embedding_bag", + "distributed/_shard/sharded_tensor/ops/test_binary_cmp", + "distributed/_shard/sharded_tensor/ops/test_init", + "distributed/_shard/sharded_tensor/ops/test_linear", + "distributed/_shard/sharded_optim/test_sharded_optim", ] + FSDP_TEST + FX2TRT_TESTS ROCM_BLOCKLIST = [ @@ -216,13 +217,13 @@ ROCM_BLOCKLIST = [ "distributed/rpc/test_faulty_agent", "distributed/rpc/test_tensorpipe_agent", "distributed/rpc/cuda/test_tensorpipe_agent", - "distributed/_sharded_tensor/test_sharded_tensor", - "distributed/_sharded_tensor/ops/test_embedding", - "distributed/_sharded_tensor/ops/test_embedding_bag", - "distributed/_sharded_tensor/ops/test_binary_cmp", - "distributed/_sharded_tensor/ops/test_init", - "distributed/_sharded_tensor/ops/test_linear", - "distributed/_sharded_optim/test_sharded_optim", + "distributed/_shard/sharded_tensor/test_sharded_tensor", + "distributed/_shard/sharded_tensor/ops/test_embedding", + "distributed/_shard/sharded_tensor/ops/test_embedding_bag", + "distributed/_shard/sharded_tensor/ops/test_binary_cmp", + "distributed/_shard/sharded_tensor/ops/test_init", + "distributed/_shard/sharded_tensor/ops/test_linear", + "distributed/_shard/sharded_optim/test_sharded_optim", "test_determination", "test_multiprocessing", "test_jit_legacy", @@ -354,14 +355,14 @@ DISTRIBUTED_TESTS = [ "distributed/elastic/utils/util_test", "distributed/elastic/utils/distributed_test", "distributed/elastic/multiprocessing/api_test", - "distributed/_sharding_spec/test_sharding_spec", - "distributed/_sharded_tensor/test_sharded_tensor", - "distributed/_sharded_tensor/ops/test_embedding", - "distributed/_sharded_tensor/ops/test_embedding_bag", - "distributed/_sharded_tensor/ops/test_binary_cmp", - "distributed/_sharded_tensor/ops/test_init", - "distributed/_sharded_tensor/ops/test_linear", - "distributed/_sharded_optim/test_sharded_optim", + 
"distributed/_shard/sharding_spec/test_sharding_spec", + "distributed/_shard/sharded_tensor/test_sharded_tensor", + "distributed/_shard/sharded_tensor/ops/test_embedding", + "distributed/_shard/sharded_tensor/ops/test_embedding_bag", + "distributed/_shard/sharded_tensor/ops/test_binary_cmp", + "distributed/_shard/sharded_tensor/ops/test_init", + "distributed/_shard/sharded_tensor/ops/test_linear", + "distributed/_shard/sharded_optim/test_sharded_optim", ] + [test for test in TESTS if test.startswith("distributed/fsdp")] # Dictionary matching test modules (in TESTS) to lists of test cases (within that test_module) that would be run when diff --git a/torch/distributed/_shard/__init__.py b/torch/distributed/_shard/__init__.py new file mode 100644 index 00000000000..ffa0f4b851c --- /dev/null +++ b/torch/distributed/_shard/__init__.py @@ -0,0 +1 @@ +from .api import shard_parameter diff --git a/torch/distributed/_shard/api.py b/torch/distributed/_shard/api.py new file mode 100644 index 00000000000..c5e9060e7b9 --- /dev/null +++ b/torch/distributed/_shard/api.py @@ -0,0 +1,145 @@ +import copy +import torch +import torch.distributed as dist +from torch.distributed import distributed_c10d +from .sharding_spec import ( + ChunkShardingSpec, + ShardingSpec, +) +from torch.distributed._shard.sharding_spec._internals import ( + get_chunked_dim_size, + get_split_size, +) +from torch.distributed._shard.sharded_tensor import ( + Shard, + ShardMetadata, + ShardedTensor, +) + +def shard_parameter( + module: torch.nn.Module, + param_name: str, + sharding_spec: ShardingSpec, + src_rank=0, + process_group=None): + """ + Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that + module, it shards that parameter according to the provided + ``sharding_spec``. ``src_rank`` denotes the source rank which would be + used as the ground truth of the data which would be scattered as shards + across the rest of the ranks. 
+ + This method replaces ``module.param_name`` with a + :class:`torch.distributed._shard.sharded_tensor.ShardedTensor` + + Args: + module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded. + param_name (str): Name of the parameter of ``module`` that needs to be sharded. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + .. warning:: + Only :class:`torch.distributed._shard.sharding_spec.ShardingSpec` is + currently supported as the ``sharding_spec``. + """ + # Perform some validation first. + if not isinstance(sharding_spec, ChunkShardingSpec): + raise ValueError('Only ChunkShardingspec is supported.') + + if not hasattr(module, param_name): + raise ValueError(f'module: {module} does not have parameter with name: {param_name}') + + tensor = getattr(module, param_name) + if not isinstance(tensor, torch.Tensor): + raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}') + + if not tensor.is_contiguous(): + raise ValueError(f'param: {param_name} is not a contiguous Tensor') + + pg = process_group if process_group is not None else distributed_c10d._get_default_group() + world_size = dist.get_world_size(pg) + rank = dist.get_rank(pg) + + # Validate src_rank and sharding_spec are same across all ranks. 
+ gathered_list = [None] * world_size + dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg) + + for idx, entry in enumerate(gathered_list): + if src_rank != entry[0]: # type: ignore[index] + raise ValueError( + f'src_rank={src_rank} on rank: {rank} does not ' # type: ignore[index] + f'match with src_rank={entry[0]} on rank: {idx}') + if sharding_spec != entry[1]: # type: ignore[index] + raise ValueError( + f'sharding_spec={sharding_spec} on rank: {rank} does not ' # type: ignore[index] + f'match with sharding_spec={entry[1]} on rank: {idx}') + + # Rearrange chunks according to placement. + local_metadata = None + current_offsets = [0] * len(tensor.size()) + shards_metadata = [] + sharding_dim_size = tensor.size(sharding_spec.dim) # type: ignore[arg-type] + split_size = get_split_size(sharding_dim_size, world_size) + tensor_sizes = list(tensor.size()) + for idx, placement in enumerate(sharding_spec.placements): + chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) + shard_size = copy.deepcopy(tensor_sizes) + shard_size[sharding_spec.dim] = chunked_dim_size # type: ignore[index] + + shard_metadata = ShardMetadata( + shard_offsets=copy.deepcopy(current_offsets), + shard_sizes=shard_size, + placement=placement, + ) + shards_metadata.append(shard_metadata) + + if rank == placement.rank(): # type: ignore[union-attr] + local_metadata = shard_metadata + + current_offsets[sharding_spec.dim] += chunked_dim_size # type: ignore[index] + + # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient). + dist.broadcast(tensor, src=src_rank, group=pg) + + # Reshape to get shard for this rank and we don't want autograd + # recording here for the narrow op and 'local_shard' should be a + # leaf variable in the autograd graph. 
+ local_shard = tensor.narrow( + sharding_spec.dim, # type: ignore[arg-type] + local_metadata.shard_offsets[sharding_spec.dim], # type: ignore[union-attr, arg-type, index] + local_metadata.shard_sizes[sharding_spec.dim], # type: ignore[union-attr, index] + ).clone().detach().contiguous() + + # Sync requires_grad to local_shard. + local_shard.requires_grad = tensor.requires_grad + + # Create ShardedTensor based on local shards. + local_shards = [ + Shard( + tensor=local_shard, + metadata=local_metadata, # type: ignore[arg-type] + ) + ] + + st = ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=pg) + + # Manually set sharding_spec + st._sharding_spec = sharding_spec + + # Replace param with ShardedTensor. + + # Need to delete the attribute first since param_name might be + # torch.nn.Parameter and can't be replaced with ShardedTensor which is + # not torch.nn.Parameter. + delattr(module, param_name) + + # Now we can set the attribute appropriately. + setattr(module, param_name, st) diff --git a/torch/distributed/_sharded_optim/__init__.py b/torch/distributed/_shard/sharded_optim/__init__.py similarity index 93% rename from torch/distributed/_sharded_optim/__init__.py rename to torch/distributed/_shard/sharded_optim/__init__.py index d21b10b72c8..e3cc7309bae 100644 --- a/torch/distributed/_sharded_optim/__init__.py +++ b/torch/distributed/_shard/sharded_optim/__init__.py @@ -3,7 +3,7 @@ from .api import ShardedOptimizer import torch.nn as nn -from torch.distributed._sharded_tensor import ( +from torch.distributed._shard.sharded_tensor import ( ShardedTensor ) @@ -16,7 +16,7 @@ def named_params_with_sharded_tensor( r"""Returns an iterator over module parameters (together with the ShardedTensor parameters), yielding both the name of the parameter as well as the parameter itself. 
This is typically passed to a - :class:torch.distributed._sharded_optim.ShardedOptimizer + :class:torch.distributed._shard.sharded_optim.ShardedOptimizer Args: prefix (str): prefix to prepend to all parameter names. diff --git a/torch/distributed/_sharded_optim/api.py b/torch/distributed/_shard/sharded_optim/api.py similarity index 98% rename from torch/distributed/_sharded_optim/api.py rename to torch/distributed/_shard/sharded_optim/api.py index fc7ed727add..7accc82754a 100644 --- a/torch/distributed/_sharded_optim/api.py +++ b/torch/distributed/_shard/sharded_optim/api.py @@ -2,7 +2,7 @@ from typing import List, Union, Mapping, Dict, Any import torch.optim as optim from torch import Tensor -from torch.distributed._sharded_tensor import ShardedTensor +from torch.distributed._shard.sharded_tensor import ShardedTensor class ShardedOptimizer(optim.Optimizer): diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py new file mode 100644 index 00000000000..ed05d36c5e4 --- /dev/null +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -0,0 +1,412 @@ +# coding=utf-8 + +import copy +import functools +import torch +from torch.distributed._shard.sharding_spec import ( + ChunkShardingSpec, + ShardingSpec, +) +from torch.distributed._shard.sharding_spec._internals import ( + get_chunked_dim_size, + get_split_size, +) +from typing import List + +from .api import ( + _register_sharded_op, + CreateOp, + Shard, + ShardMetadata, + ShardedTensor, + ShardedTensorMetadata, + TensorInitParams, + TensorProperties, +) +from .utils import load_with_process_group +import torch.distributed as dist +from torch.distributed import distributed_c10d + + +def empty(sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` 
filled with uninitialized data. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. 
+ + Returns: + A :class:`ShardedTensor` object on each rank + """ + tensor_properties = TensorProperties(dtype=dtype, layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, memory_format=memory_format, ) + tensor_init_params = TensorInitParams(create_op=CreateOp.EMPTY, tensor_properties=tensor_properties, ) + return ShardedTensor( + sharding_spec, + *size, + tensor_init_params=tensor_init_params, + process_group=process_group, + init_rrefs=init_rrefs, + ) + +def ones(sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` with the scalar value 1. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. 
+ Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + tensor_properties = TensorProperties(dtype=dtype, layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, memory_format=memory_format, ) + tensor_init_params = TensorInitParams(create_op=CreateOp.ONES, tensor_properties=tensor_properties) + return ShardedTensor( + sharding_spec, + *size, + tensor_init_params=tensor_init_params, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def rand(sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` filled with random numbers from a uniform distribution on the + interval :math:`[0, 1)`. Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. 
+ init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + tensor_properties = TensorProperties( + dtype=dtype, layout=layout, requires_grad=requires_grad, + pin_memory=pin_memory, memory_format=memory_format + ) + tensor_init_params = TensorInitParams(create_op=CreateOp.RAND, tensor_properties=tensor_properties, ) + return ShardedTensor( + sharding_spec, + *size, + tensor_init_params=tensor_init_params, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def zeros(sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` filled with the scalar value 0. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. 
If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + tensor_properties = TensorProperties( + dtype=dtype, layout=layout, requires_grad=requires_grad, + pin_memory=pin_memory, memory_format=memory_format, + ) + tensor_init_params = TensorInitParams(create_op=CreateOp.ZEROS, tensor_properties=tensor_properties, ) + return ShardedTensor( + sharding_spec, + *size, + tensor_init_params=tensor_init_params, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def full(sharding_spec: ShardingSpec, + size, + fill_value=torch.types.Number, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with fill_value. The tensor’s dtype + is inferred from fill_value. If dtype is specified, it will override the + inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the + output tensor. + fill_value (Scalar) – the value to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. 
+ pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + tensor_properties = TensorProperties( + dtype=dtype, layout=layout, requires_grad=requires_grad, + pin_memory=pin_memory, memory_format=memory_format, + ) + tensor_init_params = TensorInitParams( + create_op=CreateOp.FULL, fill_value=fill_value, tensor_properties=tensor_properties) + return ShardedTensor( + sharding_spec, + *size, + tensor_init_params=tensor_init_params, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def init_from_local_shards( + local_shards: List[Shard], + *global_size, + process_group=None, + init_rrefs=False) -> ShardedTensor: + """ + Creates an :class:`ShardedTensor` from local shards and the global metadata. + Needs to be called on all ranks in an SPMD fashion. + + Args: + local_shards (List[:class `torch.distributed._shard.sharded_tensor.Shard`]): A list + of shards that represent the local shards on this rank. + global_size (int...): a list, tuple, or `torch.Size` of integers defining the + shape of the overall sharded tensor. + + Keyword args: + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. 
+ + Returns: + A :class:`ShardedTensor` object handle on this rank + + + Examples: + Suppose we want construct a sharded tensor on two ranks, global size = (10, 5), + each shard have a (5, 5) local tensor, we can do it like below: + + on rank 0: + >>> local_shard_metadata = ShardMetadata( + >>> shard_offsets=[0, 0] + >>> shard_lengths=[5, 5] + >>> placement="rank:0/cuda:0" + >>> ) + >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)] + >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) + + on rank 1: + >>> local_shard_metadata = ShardMetadata( + >>> shard_offsets=[5, 0] + >>> shard_lengths=[5, 5] + >>> placement="rank:1/cuda:1" + >>> ) + >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)] + >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) + """ + return ShardedTensor._init_from_local_shards( + local_shards, + *global_size, + process_group=process_group, + init_rrefs=init_rrefs + ) + +def state_dict_hook(module, destination, prefix, local_metadata): + """ + Hook to add ShardedTensor to Module's ``state_dict``. Needs to be + registered to the Module using + :meth:`torch.nn.Module._register_state_dict_hook`. + """ + for submodule_name, submodule in module.named_modules(): + for attr_name, attr in submodule.__dict__.items(): + if isinstance(attr, ShardedTensor): + destination[prefix + submodule_name + '.' + attr_name] = attr + +def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + """ + Pre-load state dict hook to add ShardedTensor to the module. + """ + for submodule_name, submodule in module.named_modules(): + for attr_name, attr in submodule.__dict__.items(): + key = prefix + submodule_name + '.' 
+ attr_name + if key in state_dict: + if isinstance(state_dict[key], ShardedTensor): + setattr(submodule, attr_name, state_dict[key]) + +def sharded_op_impl(func): + """ + Provides a way for users to write their own custom sharded operator. This + can be used to override existing ShardedTensor operators or write a new + one not supported by ShardedTensor. If the operator in question is covered + by ``__torch_function__`` dispatch and has a ShardedTensor as any of its + parameters, the function provided will be invoked for that operator. + + Example:: + >>> @custom_sharded_op(torch.nn.functional.linear) + >>> def my_custom_sharded_linear(types, args, kwargs, process_group): + >>> .... + >>> + >>> input = torch.rand(10, 32) + >>> weight = sharded_tensor.rand(32, 16) + >>> bias = torch.rand(16) + >>> # This will call 'my_custom_sharded_linear' + >>> torch.nn.functional.linear(input, weight, bias) + + The types, args and kwargs parameters are the same parameters that are + passed to ``__torch_function__`` dispatch API + (https://pytorch.org/docs/stable/notes/extending.html#extending-torch). + There is an additional ``process_group`` parameter which is the + process_group used for the ShardedTensor and can be used by + implementations for communications within a sharded implementation. 
+ + Args: + func(Callable): Torch function for which we want to provide a sharded + implementation (ex: torch.nn.functional.linear) + """ + def decorator_sharded_func(wrapped_func): + _register_sharded_op(func, wrapped_func) + + @functools.wraps(wrapped_func) + def wrapper(*args, **kwargs): + return wrapped_func(*args, **kwargs) + return wrapper + return decorator_sharded_func + +# Import all builtin sharded ops +from ._ops import * # noqa: F403 diff --git a/torch/distributed/_sharded_tensor/ops/__init__.py b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py similarity index 100% rename from torch/distributed/_sharded_tensor/ops/__init__.py rename to torch/distributed/_shard/sharded_tensor/_ops/__init__.py diff --git a/torch/distributed/_sharded_tensor/ops/_common.py b/torch/distributed/_shard/sharded_tensor/_ops/_common.py similarity index 99% rename from torch/distributed/_sharded_tensor/ops/_common.py rename to torch/distributed/_shard/sharded_tensor/_ops/_common.py index 84d8bafdfb7..79adab8f529 100644 --- a/torch/distributed/_sharded_tensor/ops/_common.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/_common.py @@ -4,7 +4,7 @@ from typing import List import torch import torch.distributed as dist -from torch.distributed._sharding_spec._internals import ( +from torch.distributed._shard.sharding_spec._internals import ( get_split_size, get_chunked_dim_size, ) diff --git a/torch/distributed/_sharded_tensor/ops/binary_cmp.py b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py similarity index 97% rename from torch/distributed/_sharded_tensor/ops/binary_cmp.py rename to torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py index 2a0be14739d..cdee16d1890 100644 --- a/torch/distributed/_sharded_tensor/ops/binary_cmp.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py @@ -1,7 +1,7 @@ import torch import torch.distributed as dist import torch.distributed.distributed_c10d as distributed_c10d -from 
torch.distributed._sharded_tensor import ( +from torch.distributed._shard.sharded_tensor import ( ShardedTensor, sharded_op_impl ) diff --git a/torch/distributed/_sharded_tensor/ops/embedding.py b/torch/distributed/_shard/sharded_tensor/_ops/embedding.py similarity index 98% rename from torch/distributed/_sharded_tensor/ops/embedding.py rename to torch/distributed/_shard/sharded_tensor/_ops/embedding.py index 563c250b327..2fa1e7dd2d5 100644 --- a/torch/distributed/_sharded_tensor/ops/embedding.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/embedding.py @@ -4,14 +4,14 @@ from typing import cast import torch import torch.distributed as dist -from torch.distributed._sharded_tensor.ops._common import ( +from ._common import ( _communicate_size_to_each_rank, _handle_col_wise_sharding_base, _handle_row_wise_lookup_distribute, _handle_max_norm_col_wise, ) -from torch.distributed._sharding_spec import ChunkShardingSpec -from torch.distributed._sharded_tensor import ( +from torch.distributed._shard.sharding_spec import ChunkShardingSpec +from torch.distributed._shard.sharded_tensor import ( sharded_op_impl, ShardedTensor ) diff --git a/torch/distributed/_sharded_tensor/ops/embedding_bag.py b/torch/distributed/_shard/sharded_tensor/_ops/embedding_bag.py similarity index 99% rename from torch/distributed/_sharded_tensor/ops/embedding_bag.py rename to torch/distributed/_shard/sharded_tensor/_ops/embedding_bag.py index 40fb2320c3f..2de9f273862 100644 --- a/torch/distributed/_sharded_tensor/ops/embedding_bag.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/embedding_bag.py @@ -7,15 +7,15 @@ import torch.distributed as dist from torch._C._distributed_c10d import ( ReduceOp, ) -from torch.distributed._sharded_tensor.ops._common import ( +from ._common import ( _communicate_list_to_each_rank, _communicate_size_to_each_rank, _handle_col_wise_sharding_base, _handle_row_wise_lookup_distribute, _handle_max_norm_col_wise, ) -from torch.distributed._sharding_spec import 
ChunkShardingSpec -from torch.distributed._sharded_tensor import ( +from torch.distributed._shard.sharding_spec import ChunkShardingSpec +from torch.distributed._shard.sharded_tensor import ( sharded_op_impl, ShardedTensor ) diff --git a/torch/distributed/_sharded_tensor/ops/init.py b/torch/distributed/_shard/sharded_tensor/_ops/init.py similarity index 98% rename from torch/distributed/_sharded_tensor/ops/init.py rename to torch/distributed/_shard/sharded_tensor/_ops/init.py index 3b6b305c3bb..4768a10bfcf 100644 --- a/torch/distributed/_sharded_tensor/ops/init.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/init.py @@ -1,5 +1,5 @@ import torch -from torch.distributed._sharded_tensor import ( +from torch.distributed._shard.sharded_tensor import ( sharded_op_impl, ) diff --git a/torch/distributed/_sharded_tensor/ops/linear.py b/torch/distributed/_shard/sharded_tensor/_ops/linear.py similarity index 97% rename from torch/distributed/_sharded_tensor/ops/linear.py rename to torch/distributed/_shard/sharded_tensor/_ops/linear.py index 6ecec716352..e8255c20c33 100644 --- a/torch/distributed/_sharded_tensor/ops/linear.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/linear.py @@ -2,11 +2,11 @@ from typing import List, cast import torch import torch.distributed as dist -from torch.distributed._sharded_tensor.ops._common import ( +from ._common import ( _handle_col_wise_sharding_base, ) -from torch.distributed._sharding_spec import ChunkShardingSpec -from torch.distributed._sharding_spec._internals import ( +from torch.distributed._shard.sharding_spec import ChunkShardingSpec +from torch.distributed._shard.sharding_spec._internals import ( get_split_size, get_chunked_dim_size, ) @@ -15,7 +15,7 @@ from torch.distributed.nn.functional import ( reduce_scatter, ) -from torch.distributed._sharded_tensor import ( +from torch.distributed._shard.sharded_tensor import ( sharded_op_impl, ShardedTensor ) diff --git a/torch/distributed/_sharded_tensor/api.py 
b/torch/distributed/_shard/sharded_tensor/api.py similarity index 98% rename from torch/distributed/_sharded_tensor/api.py rename to torch/distributed/_shard/sharded_tensor/api.py index 790b18c838d..da4ef7ba65f 100644 --- a/torch/distributed/_sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -14,13 +14,13 @@ import torch import torch.distributed as dist from torch.distributed import rpc from torch.distributed import distributed_c10d -from torch.distributed._sharding_spec import ( +from torch.distributed._shard.sharding_spec import ( ChunkShardingSpec, EnumerableShardingSpec, ShardMetadata, ShardingSpec, ) -from torch.distributed._sharding_spec._internals import ( +from torch.distributed._shard.sharding_spec._internals import ( check_tensor, get_split_size, get_chunked_dim_size, @@ -108,13 +108,13 @@ class ShardedTensor(object): ShardedTensor doesn't provide any Tensor like operations but is a wrapper providing the Tensor representing the local shard and the global metadata. - Using these, users can build their custom distributed sharded computations + Using these, users can build their custom distributed._sharded computations on top of this primitive. The local shards are all initialized using the create_op specified by tensor_init_params.create_op, e.g., torch.ones, or torch.empty Args: - sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification describing how to shard the Tensor. size (int...): a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple. @@ -138,7 +138,7 @@ class ShardedTensor(object): def __new__(cls, *args, **kwargs): # Use __new__ for logging purposes. 
- torch._C._log_api_usage_once("torch.distributed.sharded_tensor") + torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor") return super(ShardedTensor, cls).__new__(cls) def __init__( diff --git a/torch/distributed/_sharded_tensor/metadata.py b/torch/distributed/_shard/sharded_tensor/metadata.py similarity index 97% rename from torch/distributed/_sharded_tensor/metadata.py rename to torch/distributed/_shard/sharded_tensor/metadata.py index 057461ad07a..b9a43400ef2 100644 --- a/torch/distributed/_sharded_tensor/metadata.py +++ b/torch/distributed/_shard/sharded_tensor/metadata.py @@ -3,7 +3,7 @@ from enum import Enum from typing import List import torch -from torch.distributed._sharding_spec import ShardMetadata +from torch.distributed._shard.sharding_spec import ShardMetadata class MEM_FORMAT_ENCODING(Enum): diff --git a/torch/distributed/_sharded_tensor/shard.py b/torch/distributed/_shard/sharded_tensor/shard.py similarity index 93% rename from torch/distributed/_sharded_tensor/shard.py rename to torch/distributed/_shard/sharded_tensor/shard.py index 08e9767f46b..7436cfa3388 100644 --- a/torch/distributed/_sharded_tensor/shard.py +++ b/torch/distributed/_shard/sharded_tensor/shard.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import List, cast import torch -from torch.distributed._sharding_spec import ShardMetadata +from torch.distributed._shard.sharding_spec import ShardMetadata from torch.distributed.remote_device import _remote_device @@ -14,7 +14,7 @@ class Shard(object): Args: tensor(torch.Tensor): Local tensor for the shard. - metadata(:class `torch.distributed._sharded_tensor.ShardMetadata`): + metadata(:class `torch.distributed._shard.sharded_tensor.ShardMetadata`): The metadata for the shard, including offsets, lengths and device placement. 
""" __slots__ = ['tensor', 'metadata'] diff --git a/torch/distributed/_sharded_tensor/utils.py b/torch/distributed/_shard/sharded_tensor/utils.py similarity index 98% rename from torch/distributed/_sharded_tensor/utils.py rename to torch/distributed/_shard/sharded_tensor/utils.py index 773a199ae52..98fa1140c44 100644 --- a/torch/distributed/_sharded_tensor/utils.py +++ b/torch/distributed/_shard/sharded_tensor/utils.py @@ -5,10 +5,10 @@ from typing import Optional, List, Sequence import torch from torch.distributed import distributed_c10d from torch.distributed import rpc -from torch.distributed._sharding_spec import ( +from torch.distributed._shard.sharding_spec import ( ShardMetadata, ) -from torch.distributed._sharding_spec._internals import ( +from torch.distributed._shard.sharding_spec._internals import ( check_tensor, validate_non_overlapping_shards_metadata, ) diff --git a/torch/distributed/_shard/sharding_spec/__init__.py b/torch/distributed/_shard/sharding_spec/__init__.py new file mode 100644 index 00000000000..f25c849559d --- /dev/null +++ b/torch/distributed/_shard/sharding_spec/__init__.py @@ -0,0 +1,8 @@ +from .api import ( + ChunkShardingSpec, + DevicePlacementSpec, + EnumerableShardingSpec, + PlacementSpec, + ShardMetadata, + ShardingSpec, +) diff --git a/torch/distributed/_sharding_spec/_internals.py b/torch/distributed/_shard/sharding_spec/_internals.py similarity index 100% rename from torch/distributed/_sharding_spec/_internals.py rename to torch/distributed/_shard/sharding_spec/_internals.py diff --git a/torch/distributed/_sharding_spec/api.py b/torch/distributed/_shard/sharding_spec/api.py similarity index 100% rename from torch/distributed/_sharding_spec/api.py rename to torch/distributed/_shard/sharding_spec/api.py diff --git a/torch/distributed/_sharded_tensor/__init__.py b/torch/distributed/_sharded_tensor/__init__.py index d6cef1046ac..9e6b1662589 100644 --- a/torch/distributed/_sharded_tensor/__init__.py +++ 
b/torch/distributed/_sharded_tensor/__init__.py @@ -1,540 +1,12 @@ -# coding=utf-8 - -import copy -import functools +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `torch.distributed._shard` package. +import sys import torch -from torch.distributed._sharding_spec import ( - ChunkShardingSpec, - ShardingSpec, +import warnings + +from torch.distributed._shard.sharded_tensor import * # noqa: F403 +warnings.warn( + "torch.distributed._sharded_tensor will be deprecated, use torch.distributed._shard.sharded_tensor instead", + DeprecationWarning ) -from torch.distributed._sharding_spec._internals import ( - get_chunked_dim_size, - get_split_size, -) -from typing import List - -from .api import ( - _register_sharded_op, - CreateOp, - Shard, - ShardMetadata, - ShardedTensor, - ShardedTensorMetadata, - TensorInitParams, - TensorProperties, -) -from .utils import load_with_process_group -import torch.distributed as dist -from torch.distributed import distributed_c10d - - -def empty(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: - """ - Returns a :class:`ShardedTensor` filled with uninitialized data. - Needs to be called on all ranks in an SPMD fashion. - - Args: - sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification - describing how to shard the Tensor. - size (int...): a sequence of integers defining the shape of the output - tensor. Can be a variable number of arguments or a collection like a list or tuple. - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. - Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). - layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. - Default: ``torch.strided``. 
- requires_grad (bool, optional): If autograd should record operations on the - returned tensor. Default: ``False``. - pin_memory (bool, optional): If set, returned tensor would be allocated in - the pinned memory. Works only for CPU tensors. Default: ``False``. - memory_format (:class:`torch.memory_format`, optional): the desired memory format of - returned Tensor. Default: ``torch.contiguous_format``. - process_group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. - init_rrefs (bool, optional): Whether or not to initialize - :class:`torch.distributed.rpc.RRef`s pointing to remote shards. - Need to initialize the RPC Framework if specified as ``True``. - Default: ``False``. - - Returns: - A :class:`ShardedTensor` object on each rank - """ - tensor_properties = TensorProperties(dtype=dtype, layout=layout, - requires_grad=requires_grad, - pin_memory=pin_memory, memory_format=memory_format, ) - tensor_init_params = TensorInitParams(create_op=CreateOp.EMPTY, tensor_properties=tensor_properties, ) - return ShardedTensor( - sharding_spec, - *size, - tensor_init_params=tensor_init_params, - process_group=process_group, - init_rrefs=init_rrefs, - ) - -def ones(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: - """ - Returns a :class:`ShardedTensor` with the scalar value 1. - Needs to be called on all ranks in an SPMD fashion. - - Args: - sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification - describing how to shard the Tensor. - size (int...): a sequence of integers defining the shape of the output - tensor. Can be a variable number of arguments or a collection like a list or tuple. - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. 
- Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). - layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned tensor. Default: ``False``. - pin_memory (bool, optional): If set, returned tensor would be allocated in - the pinned memory. Works only for CPU tensors. Default: ``False``. - process_group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. - init_rrefs (bool, optional): Whether or not to initialize - :class:`torch.distributed.rpc.RRef`s pointing to remote shards. - Need to initialize the RPC Framework if specified as ``True``. - Default: ``False``. - - Returns: - A :class:`ShardedTensor` object on each rank - """ - tensor_properties = TensorProperties(dtype=dtype, layout=layout, - requires_grad=requires_grad, - pin_memory=pin_memory, memory_format=memory_format, ) - tensor_init_params = TensorInitParams(create_op=CreateOp.ONES, tensor_properties=tensor_properties) - return ShardedTensor( - sharding_spec, - *size, - tensor_init_params=tensor_init_params, - process_group=process_group, - init_rrefs=init_rrefs, - ) - - -def rand(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: - """ - Returns a :class:`ShardedTensor` filled with random numbers from a uniform distribution on the - interval :math:`[0, 1)`. Needs to be called on all ranks in an SPMD fashion. - - Args: - sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification - describing how to shard the Tensor. - size (int...): a sequence of integers defining the shape of the output - tensor. Can be a variable number of arguments or a collection like a list or tuple. 
- - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. - Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). - layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned tensor. Default: ``False``. - pin_memory (bool, optional): If set, returned tensor would be allocated in - the pinned memory. Works only for CPU tensors. Default: ``False``. - process_group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. - init_rrefs (bool, optional): Whether or not to initialize - :class:`torch.distributed.rpc.RRef`s pointing to remote shards. - Need to initialize the RPC Framework if specified as ``True``. - Default: ``False``. - - Returns: - A :class:`ShardedTensor` object on each rank - """ - tensor_properties = TensorProperties( - dtype=dtype, layout=layout, requires_grad=requires_grad, - pin_memory=pin_memory, memory_format=memory_format - ) - tensor_init_params = TensorInitParams(create_op=CreateOp.RAND, tensor_properties=tensor_properties, ) - return ShardedTensor( - sharding_spec, - *size, - tensor_init_params=tensor_init_params, - process_group=process_group, - init_rrefs=init_rrefs, - ) - - -def zeros(sharding_spec: ShardingSpec, - *size, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: - """ - Returns a :class:`ShardedTensor` filled with the scalar value 0. - Needs to be called on all ranks in an SPMD fashion. - - Args: - sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification - describing how to shard the Tensor. - size (int...): a sequence of integers defining the shape of the output - tensor. 
Can be a variable number of arguments or a collection like a list or tuple. - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. - Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). - layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned tensor. Default: ``False``. - pin_memory (bool, optional): If set, returned tensor would be allocated in - the pinned memory. Works only for CPU tensors. Default: ``False``. - process_group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. - init_rrefs (bool, optional): Whether or not to initialize - :class:`torch.distributed.rpc.RRef`s pointing to remote shards. - Need to initialize the RPC Framework if specified as ``True``. - Default: ``False``. - - Returns: - A :class:`ShardedTensor` object on each rank - """ - tensor_properties = TensorProperties( - dtype=dtype, layout=layout, requires_grad=requires_grad, - pin_memory=pin_memory, memory_format=memory_format, - ) - tensor_init_params = TensorInitParams(create_op=CreateOp.ZEROS, tensor_properties=tensor_properties, ) - return ShardedTensor( - sharding_spec, - *size, - tensor_init_params=tensor_init_params, - process_group=process_group, - init_rrefs=init_rrefs, - ) - - -def full(sharding_spec: ShardingSpec, - size, - fill_value=torch.types.Number, - dtype=None, - layout=torch.strided, - requires_grad=False, - pin_memory=False, - memory_format=torch.contiguous_format, - process_group=None, - init_rrefs=False) -> ShardedTensor: - """ - Creates a :class:`ShardedTensor` filled with fill_value. The tensor’s dtype - is inferred from fill_value. If dtype is specified, it will override the - inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion. 
- - Args: - sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification - describing how to shard the Tensor. - size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the - output tensor. - fill_value (Scalar) – the value to fill the output tensor with. - - Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. - Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`). - layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. - Default: ``torch.strided``. - requires_grad (bool, optional): If autograd should record operations on the - returned tensor. Default: ``False``. - pin_memory (bool, optional): If set, returned tensor would be allocated in - the pinned memory. Works only for CPU tensors. Default: ``False``. - process_group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. - init_rrefs (bool, optional): Whether or not to initialize - :class:`torch.distributed.rpc.RRef`s pointing to remote shards. - Need to initialize the RPC Framework if specified as ``True``. - Default: ``False``. - - Returns: - A :class:`ShardedTensor` object on each rank - """ - tensor_properties = TensorProperties( - dtype=dtype, layout=layout, requires_grad=requires_grad, - pin_memory=pin_memory, memory_format=memory_format, - ) - tensor_init_params = TensorInitParams( - create_op=CreateOp.FULL, fill_value=fill_value, tensor_properties=tensor_properties) - return ShardedTensor( - sharding_spec, - *size, - tensor_init_params=tensor_init_params, - process_group=process_group, - init_rrefs=init_rrefs, - ) - - -def init_from_local_shards( - local_shards: List[Shard], - *global_size, - process_group=None, - init_rrefs=False) -> ShardedTensor: - """ - Creates an :class:`ShardedTensor` from local shards and the global metadata. - Needs to be called on all ranks in an SPMD fashion. 
- - Args: - local_shards (List[:class `torch.distributed._sharded_tensor.Shard`]): A list - of shards that represent the local shards on this rank. - global_size (int...): a list, tuple, or `torch.Size` of integers defining the - shape of the overall sharded tensor. - - Keyword args: - process_group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. - init_rrefs (bool, optional): Whether or not to initialize - :class:`torch.distributed.rpc.RRef`s pointing to remote shards. - Need to initialize the RPC Framework if specified as ``True``. - Default: ``False``. - - Returns: - A :class:`ShardedTensor` object handle on this rank - - - Examples: - Suppose we want construct a sharded tensor on two ranks, global size = (10, 5), - each shard have a (5, 5) local tensor, we can do it like below: - - on rank 0: - >>> local_shard_metadata = ShardMetadata( - >>> shard_offsets=[0, 0] - >>> shard_lengths=[5, 5] - >>> placement="rank:0/cuda:0" - >>> ) - >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)] - >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) - - on rank 1: - >>> local_shard_metadata = ShardMetadata( - >>> shard_offsets=[5, 0] - >>> shard_lengths=[5, 5] - >>> placement="rank:1/cuda:1" - >>> ) - >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)] - >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) - """ - return ShardedTensor._init_from_local_shards( - local_shards, - *global_size, - process_group=process_group, - init_rrefs=init_rrefs - ) - -def state_dict_hook(module, destination, prefix, local_metadata): - """ - Hook to add ShardedTensor to Module's ``state_dict``. Needs to be - registered to the Module using - :meth:`torch.nn.Module._register_state_dict_hook`. 
- """ - for submodule_name, submodule in module.named_modules(): - for attr_name, attr in submodule.__dict__.items(): - if isinstance(attr, ShardedTensor): - destination[prefix + submodule_name + '.' + attr_name] = attr - -def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - """ - Pre-load state dict hook to add ShardedTensor to the module. - """ - for submodule_name, submodule in module.named_modules(): - for attr_name, attr in submodule.__dict__.items(): - key = prefix + submodule_name + '.' + attr_name - if key in state_dict: - if isinstance(state_dict[key], ShardedTensor): - setattr(submodule, attr_name, state_dict[key]) - -def shard_parameter( - module: torch.nn.Module, - param_name: str, - sharding_spec: ShardingSpec, - src_rank=0, - process_group=None): - """ - Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that - module, it shards that parameter according to the provided - ``sharding_spec``. ``src_rank`` denotes the source rank which would be - used as the ground truth of the data which would be scattered as shards - across the rest of the ranks. - - This method replaces ``module.param_name`` with a - :class:`torch.distributed._sharded_tensor.ShardedTensor` - - Args: - module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded. - param_name (str): Name of the parameter of ``module`` that needs to be sharded. - sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification - describing how to shard the Tensor. - - Keyword args: - src_rank (int, optional): The source rank which is used as the ground truth of - the data for the parameter that would be sharded and scattered - across the rest of the ranks. - Default: 0. - process_group (ProcessGroup, optional): The process group to work on. If None, - the default process group will be used. - - .. 
warning:: - Only :class:`torch.distributed._sharding_spec.ShardingSpec` is - currently supported as the ``sharding_spec``. - """ - # Perform some validation first. - if not isinstance(sharding_spec, ChunkShardingSpec): - raise ValueError('Only ChunkShardingspec is supported.') - - if not hasattr(module, param_name): - raise ValueError(f'module: {module} does not have parameter with name: {param_name}') - - tensor = getattr(module, param_name) - if not isinstance(tensor, torch.Tensor): - raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}') - - if not tensor.is_contiguous(): - raise ValueError(f'param: {param_name} is not a contiguous Tensor') - - pg = process_group if process_group is not None else distributed_c10d._get_default_group() - world_size = dist.get_world_size(pg) - rank = dist.get_rank(pg) - - # Validate src_rank and sharding_spec are same across all ranks. - gathered_list = [None] * world_size - dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg) - - for idx, entry in enumerate(gathered_list): - if src_rank != entry[0]: # type: ignore[index] - raise ValueError( - f'src_rank={src_rank} on rank: {rank} does not ' # type: ignore[index] - f'match with src_rank={entry[0]} on rank: {idx}') - if sharding_spec != entry[1]: # type: ignore[index] - raise ValueError( - f'sharding_spec={sharding_spec} on rank: {rank} does not ' # type: ignore[index] - f'match with sharding_spec={entry[1]} on rank: {idx}') - - # Rearrange chunks according to placement. 
- local_metadata = None - current_offsets = [0] * len(tensor.size()) - shards_metadata = [] - sharding_dim_size = tensor.size(sharding_spec.dim) # type: ignore[arg-type] - split_size = get_split_size(sharding_dim_size, world_size) - tensor_sizes = list(tensor.size()) - for idx, placement in enumerate(sharding_spec.placements): - chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) - shard_size = copy.deepcopy(tensor_sizes) - shard_size[sharding_spec.dim] = chunked_dim_size # type: ignore[index] - - shard_metadata = ShardMetadata( - shard_offsets=copy.deepcopy(current_offsets), - shard_sizes=shard_size, - placement=placement, - ) - shards_metadata.append(shard_metadata) - - if rank == placement.rank(): # type: ignore[union-attr] - local_metadata = shard_metadata - - current_offsets[sharding_spec.dim] += chunked_dim_size # type: ignore[index] - - # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient). - dist.broadcast(tensor, src=src_rank, group=pg) - - # Reshape to get shard for this rank and we don't want autograd - # recording here for the narrow op and 'local_shard' should be a - # leaf variable in the autograd graph. - local_shard = tensor.narrow( - sharding_spec.dim, # type: ignore[arg-type] - local_metadata.shard_offsets[sharding_spec.dim], # type: ignore[union-attr, arg-type, index] - local_metadata.shard_sizes[sharding_spec.dim], # type: ignore[union-attr, index] - ).clone().detach().contiguous() - - # Sync requires_grad to local_shard. - local_shard.requires_grad = tensor.requires_grad - - # Create ShardedTensor based on local shards. - local_shards = [ - Shard( - tensor=local_shard, - metadata=local_metadata, # type: ignore[arg-type] - ) - ] - - st = ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=pg) - - # Manually set sharding_spec - st._sharding_spec = sharding_spec - - # Replace param with ShardedTensor. 
- - # Need to delete the attribute first since param_name might be - # torch.nn.Parameter and can't be replaced with ShardedTensor which is - # not torch.nn.Parameter. - delattr(module, param_name) - - # Now we can set the attribute appropriately. - setattr(module, param_name, st) - -def sharded_op_impl(func): - """ - Provides a way for users to write their own custom sharded operator. This - can be used to override existing ShardedTensor operators or write a new - one not supported by ShardedTensor. If the operator in question is covered - by ``__torch_function__`` dispatch and has a ShardedTensor as any of its - parameters, the function provided will be invoked for that operator. - - Example:: - >>> @custom_sharded_op(torch.nn.functional.linear) - >>> def my_custom_sharded_linear(types, args, kwargs, process_group): - >>> .... - >>> - >>> input = torch.rand(10, 32) - >>> weight = _sharded_tensor.rand(32, 16) - >>> bias = torch.rand(16) - >>> # This will call 'my_custom_sharded_linear' - >>> torch.nn.functional.linear(input, weight, bias) - - The types, args and kwargs parameters are the same parameters that are - passed to ``__torch_function__`` dispatch API - (https://pytorch.org/docs/stable/notes/extending.html#extending-torch). - There is an additional ``process_group`` parameter which is the - process_group used for the ShardedTensor and can be used by - implementations for communications within a sharded implementation. 
- - Args: - func(Callable): Torch function for which we want to provide a sharded - implementation (ex: torch.nn.functional.linear) - """ - def decorator_sharded_func(wrapped_func): - _register_sharded_op(func, wrapped_func) - - @functools.wraps(wrapped_func) - def wrapper(*args, **kwargs): - return wrapped_func(*args, **kwargs) - return wrapper - return decorator_sharded_func - -# Import all builtin sharded ops -from .ops import * # noqa: F403 +sys.modules['torch.distributed._sharded_tensor'] = torch.distributed._shard.sharded_tensor diff --git a/torch/distributed/_sharding_spec/__init__.py b/torch/distributed/_sharding_spec/__init__.py index f25c849559d..11e9e9a3dee 100644 --- a/torch/distributed/_sharding_spec/__init__.py +++ b/torch/distributed/_sharding_spec/__init__.py @@ -1,8 +1,12 @@ -from .api import ( - ChunkShardingSpec, - DevicePlacementSpec, - EnumerableShardingSpec, - PlacementSpec, - ShardMetadata, - ShardingSpec, +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `torch.distributed._shard` package. 
+import sys +import torch +import warnings + +from torch.distributed._shard.sharding_spec import * # noqa: F403 +warnings.warn( + "torch.distributed._sharding_spec will be deprecated, use torch.distributed._shard.sharding_spec instead", + DeprecationWarning ) +sys.modules['torch.distributed._sharding_spec'] = torch.distributed._shard.sharding_spec diff --git a/torch/testing/_internal/distributed/_shard/__init__.py b/torch/testing/_internal/distributed/_shard/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/torch/testing/_internal/distributed/_sharded_tensor/__init__.py b/torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py similarity index 100% rename from torch/testing/_internal/distributed/_sharded_tensor/__init__.py rename to torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py diff --git a/torch/testing/_internal/distributed/_sharded_tensor/_test_ops_common.py b/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py similarity index 94% rename from torch/testing/_internal/distributed/_sharded_tensor/_test_ops_common.py rename to torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py index ea3b9f551d6..13ffa6b6b87 100644 --- a/torch/testing/_internal/distributed/_sharded_tensor/_test_ops_common.py +++ b/torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py @@ -1,7 +1,7 @@ -from torch.distributed._sharding_spec import ( +from torch.distributed._shard.sharding_spec import ( ChunkShardingSpec, ) -from torch.distributed._sharding_spec._internals import ( +from torch.distributed._shard.sharding_spec._internals import ( get_chunked_dim_size, get_split_size, )