diff --git a/test/distributed/_sharding_spec/test_sharding_spec.py b/test/distributed/_sharding_spec/test_sharding_spec.py
new file mode 100644
index 00000000000..c96d28d8d2b
--- /dev/null
+++ b/test/distributed/_sharding_spec/test_sharding_spec.py
@@ -0,0 +1,56 @@
+import torch
+from torch.testing._internal.common_utils import TestCase
+from torch.distributed._sharding_spec import (
+    ChunkShardingSpec,
+    DevicePlacementSpec
+)
+
+class TestShardingSpec(TestCase):
+
+    def test_device_placement(self):
+        # valid devices
+        DevicePlacementSpec("cuda:0")
+        DevicePlacementSpec(0)
+        DevicePlacementSpec(torch.device("cuda:0"))
+        DevicePlacementSpec("rank:0/cuda:0")
+        DevicePlacementSpec("rank:0/cpu")
+        DevicePlacementSpec("rank:0")
+
+        # invalid devices
+        with self.assertRaisesRegex(ValueError, "not a valid device"):
+            DevicePlacementSpec("cuda:foo")
+        with self.assertRaisesRegex(ValueError, "not a valid device"):
+            DevicePlacementSpec("foo:0")
+        with self.assertRaisesRegex(ValueError, "not a valid device"):
+            DevicePlacementSpec("rank:0/cuda:foo")
+        with self.assertRaisesRegex(ValueError, "not a valid device"):
+            DevicePlacementSpec("rank:0/cpu2")
+
+    def test_chunked_sharding_spec(self):
+        # Test valid specs.
+        ChunkShardingSpec(0, [0, 1])
+        # Named dimension.
+        ChunkShardingSpec("N", ["cuda:0", "cuda:1"])
+        ChunkShardingSpec(0, [torch.device("cuda:0"), torch.device("cuda:1")])
+        ChunkShardingSpec(-1, ["cuda:0", "cuda:1"])
+        ChunkShardingSpec(0, ["rank:0/cuda:0", "rank:0/cuda:1"])
+        ChunkShardingSpec(0, ["rank:0", "rank:1"])
+        ChunkShardingSpec(0, ["rank:0/cpu", "rank:1/cpu"])
+
+        # Test invalid specs
+        with self.assertRaisesRegex(ValueError, "int or str"):
+            ChunkShardingSpec(None, ["cuda:0", "cuda:1"])
+        with self.assertRaisesRegex(ValueError, "int or str"):
+            ChunkShardingSpec({}, ["cuda:0", "cuda:1"])
+        with self.assertRaisesRegex(ValueError, "not a valid device"):
+            ChunkShardingSpec(0, ["random:0", "cuda:1"])
+        with self.assertRaisesRegex(ValueError, "not a valid device"):
+            ChunkShardingSpec(0, ["cuda:foo", "cuda:1"])
+        with self.assertRaisesRegex(ValueError, "not a valid device"):
+            ChunkShardingSpec(0, ["rank:foo", "cuda:1"])
+        with self.assertRaisesRegex(ValueError, "not a valid device"):
+            ChunkShardingSpec(0, ["rank:0/foo", "cuda:1"])
+        with self.assertRaisesRegex(ValueError, "not a valid device"):
+            ChunkShardingSpec(0, ["rank:0/random:0", "cuda:1"])
+        with self.assertRaisesRegex(ValueError, "not a valid device"):
+            ChunkShardingSpec(0, ["rank:0/cuda:foo", "cuda:1"])
diff --git a/test/run_test.py b/test/run_test.py
index 0314f7e6346..5b5970aef26 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -161,6 +161,7 @@ TESTS = [
     'distributed/elastic/utils/util_test',
     'distributed/elastic/utils/distributed_test',
     'distributed/elastic/multiprocessing/api_test',
+    'distributed/_sharding_spec/test_sharding_spec',
 ]
 
 # Tests need to be run with pytest.
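For intuition, the chunking these tests exercise mirrors `torch.chunk`: a `ChunkShardingSpec` with N placements describes splitting a tensor into N equal chunks along the given dimension. A minimal local sketch (plain tensors, no actual distribution involved):

```python
import torch

# A spec like ChunkShardingSpec(0, ["rank:0/cuda:0", "rank:1/cuda:1"]) describes
# two shards along dim 0; locally, that split is what torch.chunk produces.
tensor = torch.arange(8).reshape(4, 2)
shards = torch.chunk(tensor, chunks=2, dim=0)
assert len(shards) == 2
assert shards[0].shape == (2, 2)  # each shard holds an equal slice of dim 0
```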
diff --git a/torch/distributed/_sharding_spec/__init__.py b/torch/distributed/_sharding_spec/__init__.py
new file mode 100644
index 00000000000..6c85b0bd28d
--- /dev/null
+++ b/torch/distributed/_sharding_spec/__init__.py
@@ -0,0 +1,5 @@
+from .api import (
+    PlacementSpec,
+    DevicePlacementSpec,
+    ChunkShardingSpec
+)
diff --git a/torch/distributed/_sharding_spec/_internals.py b/torch/distributed/_sharding_spec/_internals.py
new file mode 100644
index 00000000000..86dc7087477
--- /dev/null
+++ b/torch/distributed/_sharding_spec/_internals.py
@@ -0,0 +1,22 @@
+import torch
+from torch.distributed.utils import _parse_remote_device
+
+def is_valid_device(device):
+    """
+    Checks if this is a valid local/remote device.
+    """
+    # Check for torch.device
+    try:
+        torch.device(device)
+        return True
+    except Exception:
+        pass
+
+    # Check for remote device.
+    try:
+        _parse_remote_device(device)
+        return True
+    except Exception:
+        pass
+
+    return False
diff --git a/torch/distributed/_sharding_spec/api.py b/torch/distributed/_sharding_spec/api.py
new file mode 100644
index 00000000000..31b90002997
--- /dev/null
+++ b/torch/distributed/_sharding_spec/api.py
@@ -0,0 +1,100 @@
+from abc import ABC
+import torch
+from typing import List, Union
+
+from ._internals import is_valid_device
+
+Device = Union[torch.device, int, str]
+
+class PlacementSpec(ABC):
+    """
+    Base class representing the placement of an entity. Subclasses of this
+    class can be used to specify customized placements which might not be
+    covered by existing APIs.
+    """
+    pass
+
+class DevicePlacementSpec(PlacementSpec):
+    """
+    Associates placement of an entity with a single device. The device can be a
+    local device or a remote device specified by one of the following remote
+    formats:
+
+        1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
+        2. "<workername>/<device>" (ex: "trainer0/cuda:0").
+
+    Args:
+        device(str, :class:`torch.device`): The device to place the entity on.
+    """
+
+    def __init__(self, device: Device):
+        super(DevicePlacementSpec, self).__init__()
+        if not is_valid_device(device):
+            raise ValueError(f'{device} is not a valid device')
+        self._device = device
+
+    @property
+    def device(self) -> Device:
+        """
+        Retrieves the device for placement.
+        """
+        return self._device
+
+class ChunkShardingSpec(PlacementSpec):
+    """
+    This is a type of PlacementSpec that defines the placement as being sharded
+    across multiple devices. In particular, it represents sharding a Tensor
+    along a single dimension into equal chunks (similar to :meth:`torch.chunk`).
+
+    Args:
+        dim (int or str):
+            The dimension to shard on. This can be an integer representing the
+            dimension or a string in the case of named tensors, where
+            dimensions are named.
+        placements(List[Device] or List[PlacementSpec]):
+            Specifies the placement of each shard of the Tensor. The size of
+            the list represents the number of shards to be created. This
+            parameter can be a list of devices
+            (ex: ["rank:0/cuda:0", "rank:1/cuda:1"]) or a list of custom
+            placement specs.
+
+            The device can be a local device or a remote device specified by one
+            of the following remote formats:
+
+                1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
+                2. "<workername>/<device>" (ex: "trainer0/cuda:0").
+    """
+
+    ShardPlacements = List[Union[Device, PlacementSpec]]
+    ShardingDim = Union[int, str]
+
+    def __init__(self, dim: ShardingDim, placements: ShardPlacements):
+        super(ChunkShardingSpec, self).__init__()
+        self._verify_dim(dim)
+        self._verify_devices(placements)
+        self._dim = dim
+        self._placements = placements
+
+    @staticmethod
+    def _verify_devices(placements):
+        for dev in placements:
+            if not isinstance(dev, PlacementSpec) and not is_valid_device(dev):
+                raise ValueError(f'{dev} is not a valid device')
+
+    @staticmethod
+    def _verify_dim(dim):
+        if not (isinstance(dim, int) or isinstance(dim, str)):
+            raise ValueError(f'{dim} needs to either be an int or str')
+
+    @property
+    def dim(self) -> ShardingDim:
+        """
+        Retrieves the dimension to shard on.
+        """
+        return self._dim
+
+    @property
+    def placements(self) -> ShardPlacements:
+        """
+        Retrieves the shard placements.
+        """
+        return self._placements
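A quick usage sketch for the API defined above, assuming this patch is applied (the package is private, hence the `_sharding_spec` prefix):

```python
from torch.distributed._sharding_spec import ChunkShardingSpec, DevicePlacementSpec

# Pin an entity to a single (possibly remote) device.
placement = DevicePlacementSpec("rank:0/cuda:0")
print(placement.device)  # rank:0/cuda:0

# Shard dim 0 across two remote GPUs; one shard per placements entry.
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
print(spec.dim)         # 0
print(spec.placements)  # ['rank:0/cuda:0', 'rank:1/cuda:1']

# Invalid inputs fail eagerly at construction time.
try:
    ChunkShardingSpec(dim=None, placements=["cuda:0"])
except ValueError as e:
    print(e)  # None needs to either be an int or str
```

Both specs validate their inputs in `__init__`, so malformed devices surface when the spec is constructed rather than when the sharded entity is later materialized.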
+ """ + + ShardPlacements = List[Union[Device, PlacementSpec]] + ShardingDim = Union[int, str] + + def __init__(self, dim: ShardingDim, placements: ShardPlacements): + super(ChunkShardingSpec, self).__init__() + self._verify_dim(dim) + self._verify_devices(placements) + self._dim = dim + self._placements = placements + + @staticmethod + def _verify_devices(placements): + for dev in placements: + if not isinstance(dev, PlacementSpec) and not is_valid_device(dev): + raise ValueError(f'{dev} is not a valid device') + + @staticmethod + def _verify_dim(dim): + if not (isinstance(dim, int) or isinstance(dim, str)): + raise ValueError(f'{dim} needs to either be an int or str') + + @property + def dim(self) -> ShardingDim: + """ + Retrieves the dimension to shard on. + """ + return self._dim + + @property + def placements(self) -> ShardPlacements: + """ + Retrieves the shard placements. + """ + return self._placements diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index 73bded145a8..3442271694b 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -20,7 +20,7 @@ import torch.distributed.rpc as rpc from torch import Tensor, device, dtype, nn from torch.distributed.nn.jit import instantiator from torch.distributed.rpc.internal import _internal_rpc_pickler -from torch.distributed.rpc.utils import _parse_remote_device +from torch.distributed.utils import _parse_remote_device from torch.nn import Module from torch.nn.parameter import Parameter from torch.utils.hooks import RemovableHandle @@ -154,8 +154,12 @@ class _RemoteModule(nn.Module): Args: remote_device (str): Device on the destination worker where we'd like to place this module. - The format should be "/", where the device field can be parsed as torch.device type. - E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + The device can be a local device or a remote device specified by one of the following remote + formats: + + 1. "rank:/" (ex: "rank:0/cuda:0"). + 2. "/" (ex: "trainer0/cuda:0"). + In addition, the device field can be optional and the default value is "cpu". module_cls (nn.Module): For example, >>> class MyModule(nn.Module): diff --git a/torch/distributed/rpc/utils.py b/torch/distributed/rpc/utils.py deleted file mode 100644 index afdde21f3c5..00000000000 --- a/torch/distributed/rpc/utils.py +++ /dev/null @@ -1,37 +0,0 @@ -def _parse_remote_device(remote_device: str): - r""" - Parses the remote device. - - Args: - remote_device (str): Device on the destination worker where we'd like to place this module. - The format should be "/", where the device field can be parsed as torch.device type. - E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". - In addition, the device field can be optional and the default value is "cpu". - - Returns: - A workername and a device. - """ - fields = remote_device.split("/") - if len(fields) == 2: - [on, device] = fields - elif len(fields) == 1: - on = fields[0] - device = "cpu" - else: - raise RuntimeError( - "Could not parse remote_device: {}. The valid format is '/'".format( - remote_device - ) - ) - - # Since the workername in the input remote device won't be validated until the created remote module is executed, - # only do some very basic sanity check on workername at the module creation time. - # As currently there is no regex to describe the format of workername, just check whether the workername is empty. - if not on: - raise RuntimeError( - "The workername in remote_device '{}' cannot be empty. 
diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py
new file mode 100644
index 00000000000..323b3608023
--- /dev/null
+++ b/torch/distributed/utils.py
@@ -0,0 +1,56 @@
+import torch
+
+def _parse_remote_device(remote_device: str):
+    r"""
+    Parses the remote device.
+
+    Args:
+        remote_device (str): Device on the destination worker where we'd like to place this module.
+            The format should be one of the following:
+
+                1. "<workername>/<device>", where the device field can be parsed as torch.device type.
+                   E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
+                   In addition, the device field can be optional and the default value is "cpu".
+                2. "rank:<rank>/<device>", where <rank> is the rank of the
+                   process and device can be parsed as torch.device type.
+                   E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
+
+    Returns:
+        A workername/rank and a device.
+    """
+
+    PARSE_ERROR = (
+        f"Could not parse remote_device: {remote_device}. The valid format is "
+        "'<workername>/<device>' or 'rank:<rank>/<device>'"
+    )
+
+    fields = remote_device.split("/")
+    if len(fields) == 2:
+        [on, device] = fields
+    elif len(fields) == 1:
+        on = fields[0]
+        device = "cpu"
+    else:
+        raise ValueError(PARSE_ERROR)
+
+    # Since the workername in the input remote device won't be validated until the created remote module is executed,
+    # only do some very basic sanity check on workername at the module creation time.
+    # As currently there is no regex to describe the format of workername, just check whether the workername is empty.
+    if not on:
+        raise ValueError(PARSE_ERROR)
+
+    # Validate the device.
+    torch.device(device)
+
+    # Check for rank based format
+    fields = on.split(':')
+    if len(fields) == 2:
+        # rank:<rank>/device format, extract rank
+        if fields[0] == 'rank' and fields[1].isdigit():
+            on = int(fields[1])  # type: ignore[assignment]
+        else:
+            raise ValueError(PARSE_ERROR)
+    elif len(fields) > 2:
+        raise ValueError(PARSE_ERROR)
+
+    return on, device
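To make the parser's contract concrete, here is a sketch of what `_parse_remote_device` returns for each accepted format, consistent with the test updates below (the rank form comes back as an `int`, and parse failures are now `ValueError` rather than `RuntimeError`):

```python
from torch.distributed.utils import _parse_remote_device

print(_parse_remote_device("trainer0/cuda:0"))  # ('trainer0', 'cuda:0')
print(_parse_remote_device("trainer0"))         # ('trainer0', 'cpu'); device defaults to "cpu"
print(_parse_remote_device("rank:0/cuda:0"))    # (0, 'cuda:0'); rank parsed as int

try:
    _parse_remote_device("worker1/cuda:0/cuda:1")  # too many '/' separators
except ValueError as e:
    print(e)  # Could not parse remote_device: worker1/cuda:0/cuda:1. ...
```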
diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py
index b4236f3224c..3b5a441b7a5 100644
--- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py
+++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py
@@ -554,7 +554,8 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
     def test_valid_device(self):
         if self.rank != 0:
             return
-        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
+        dst_rank = (self.rank + 1) % self.world_size
+        dst_worker_name = dist_utils.worker_name(dst_rank)
 
         for remote_module in self._create_remote_module_iter(
             "{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR]
@@ -565,6 +566,16 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
             self.assertEqual(device.type, "cuda")
             self.assertEqual(device.index, 0)
 
+        # Test rank works as well.
+        for remote_module in self._create_remote_module_iter(
+            "rank:{}/cuda:0".format(dst_rank), modes=[ModuleCreationMode.MODULE_CTOR]
+        ):
+            device = rpc.rpc_sync(
+                dst_worker_name, remote_device, (remote_module.module_rref,)
+            )
+            self.assertEqual(device.type, "cuda")
+            self.assertEqual(device.index, 0)
+
     @skip_if_lt_x_gpu(1)
     @dist_utils.dist_init
     def test_invalid_devices(self):
@@ -614,7 +625,7 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
         )
 
         with self.assertRaisesRegex(
-            RuntimeError,
+            ValueError,
             r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '<workername>/<device>'",
         ):
             list(
                 m.forward() for m in

         with self.assertRaisesRegex(
-            RuntimeError,
-            r"The workername in remote_device '/' cannot be empty. The valid format is '<workername>/<device>'",
+            ValueError,
+            r"Could not parse remote_device: /. The valid format is '<workername>/<device>'",
         ):
             list(
                 m.forward() for m in

         with self.assertRaisesRegex(
-            RuntimeError,
-            r"The workername in remote_device '/cuda:0' cannot be empty. The valid format is '<workername>/<device>'",
+            ValueError,
+            r"Could not parse remote_device: /cuda:0. The valid format is '<workername>/<device>'",
         ):
             list(
                 m.forward() for m in