Introduce ChunkShardingSpec as a model sharding specification. (#55728)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55728

Full design: https://github.com/pytorch/pytorch/issues/55207

This PR introduces ChunkShardingSpec (called SingleShardingSpec in the design).
The name ChunkShardingSpec was chosen since the spec splits a Tensor much like
`torch.chunk` does, and reads more clearly than SingleShardingSpec.
ghstack-source-id: 129603318

Test Plan: waitforbuildbot

Reviewed By: SciPioneer

Differential Revision: D27694108

fbshipit-source-id: c8764abe6a4d5fc56d023fda29b74b5af2a73b49
Author: Pritam Damania
Date: 2021-05-23 16:03:22 -07:00
Committed by: Facebook GitHub Bot
Parent: c5a1f04367
Commit: 0d6fa1adc5
9 changed files with 264 additions and 46 deletions
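
For orientation before the diffs, a minimal usage sketch of the two new specs,
mirroring the valid cases exercised in the new test file below. Note that a
spec only describes placement; it does not move or split any data by itself.

    from torch.distributed._sharding_spec import (
        ChunkShardingSpec,
        DevicePlacementSpec,
    )

    # Pin an entity to a single device on rank 0.
    device_spec = DevicePlacementSpec("rank:0/cuda:0")

    # Shard a Tensor along dim 0 into two chunks, one per rank/GPU.
    chunk_spec = ChunkShardingSpec(0, ["rank:0/cuda:0", "rank:1/cuda:1"])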


@@ -0,0 +1,56 @@
import torch
from torch.testing._internal.common_utils import TestCase
from torch.distributed._sharding_spec import (
    ChunkShardingSpec,
    DevicePlacementSpec
)

class TestShardingSpec(TestCase):

    def test_device_placement(self):
        # valid devices
        DevicePlacementSpec("cuda:0")
        DevicePlacementSpec(0)
        DevicePlacementSpec(torch.device("cuda:0"))
        DevicePlacementSpec("rank:0/cuda:0")
        DevicePlacementSpec("rank:0/cpu")
        DevicePlacementSpec("rank:0")

        # invalid devices
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            DevicePlacementSpec("cuda:foo")
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            DevicePlacementSpec("foo:0")
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            DevicePlacementSpec("rank:0/cuda:foo")
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            DevicePlacementSpec("rank:0/cpu2")

    def test_chunked_sharding_spec(self):
        # Test valid specs.
        ChunkShardingSpec(0, [0, 1])
        # Named dimension.
        ChunkShardingSpec("N", ["cuda:0", "cuda:1"])
        ChunkShardingSpec(0, [torch.device("cuda:0"), torch.device("cuda:1")])
        ChunkShardingSpec(-1, ["cuda:0", "cuda:1"])
        ChunkShardingSpec(0, ["rank:0/cuda:0", "rank:0/cuda:1"])
        ChunkShardingSpec(0, ["rank:0", "rank:1"])
        ChunkShardingSpec(0, ["rank:0/cpu", "rank:1/cpu"])

        # Test invalid specs
        with self.assertRaisesRegex(ValueError, "int or str"):
            ChunkShardingSpec(None, ["cuda:0", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "int or str"):
            ChunkShardingSpec({}, ["cuda:0", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            ChunkShardingSpec(0, ["random:0", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            ChunkShardingSpec(0, ["cuda:foo", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            ChunkShardingSpec(0, ["rank:foo", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            ChunkShardingSpec(0, ["rank:0/foo", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            ChunkShardingSpec(0, ["rank:0/random:0", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            ChunkShardingSpec(0, ["rank:0/cuda:foo", "cuda:1"])


@@ -161,6 +161,7 @@ TESTS = [
    'distributed/elastic/utils/util_test',
    'distributed/elastic/utils/distributed_test',
    'distributed/elastic/multiprocessing/api_test',
    'distributed/_sharding_spec/test_sharding_spec',
]
# Tests need to be run with pytest.


@@ -0,0 +1,5 @@
from .api import (
    PlacementSpec,
    DevicePlacementSpec,
    ChunkShardingSpec
)


@@ -0,0 +1,22 @@
import torch

from torch.distributed.utils import _parse_remote_device

def is_valid_device(device):
    """
    Checks if this is a valid local/remote device.
    """
    # Check for torch.device
    try:
        torch.device(device)
        return True
    except Exception:
        pass

    # Check for remote device.
    try:
        _parse_remote_device(device)
        return True
    except Exception:
        pass

    return False
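
A quick sketch of the helper's behavior, assuming it is importable from the
_sharding_spec package's internals module (per the relative import in the spec
API below):

    is_valid_device("cuda:0")         # True: parses as a torch.device
    is_valid_device("rank:0/cuda:0")  # True: parses as a remote device
    is_valid_device("foo:0")          # False: rejected by both parsers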


@@ -0,0 +1,100 @@
from abc import ABC
import torch
from typing import List, Union

from ._internals import is_valid_device

Device = Union[torch.device, int, str]

class PlacementSpec(ABC):
    """
    Base class representing the placement of an entity. Subclasses of this
    class can be used to specify customized placements which might not be
    covered by existing APIs.
    """
    pass

class DevicePlacementSpec(PlacementSpec):
    """
    Associates placement of an entity with a single device. The device can be a
    local device or a remote device specified by one of the following remote
    formats:

        1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
        2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").

    Args:
        device(str, :class:`torch.device`): The device to place the entity on.
    """

    def __init__(self, device: Device):
        super(DevicePlacementSpec, self).__init__()
        if not is_valid_device(device):
            raise ValueError(f'{device} is not a valid device')
        self._device = device

    @property
    def device(self) -> Device:
        """
        Retrieves the device for placement.
        """
        return self._device

class ChunkShardingSpec(PlacementSpec):
    """
    This is a type of PlacementSpec that defines the placement as being sharded
    across multiple devices. In particular, it represents sharding a Tensor
    along a single dimension into equal chunks (similar to :meth:`torch.chunk`).

    Args:
        dim (int or str):
            The dimension to shard on; this can be an integer representing the
            dimension or, in the case of named tensors, a string naming the
            dimension.
        placements(List[Device] or List[PlacementSpec]):
            Specifies the placement of each shard of the Tensor. The size of
            the list represents the number of shards to be created. This
            parameter can be a list of devices
            (ex: ["rank:0/cuda:0", "rank:1/cuda:1"]) or a list of custom
            placement specs.

            The device can be a local device or a remote device specified by
            one of the following remote formats:

                1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
                2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").
    """

    ShardPlacements = List[Union[Device, PlacementSpec]]
    ShardingDim = Union[int, str]

    def __init__(self, dim: ShardingDim, placements: ShardPlacements):
        super(ChunkShardingSpec, self).__init__()
        self._verify_dim(dim)
        self._verify_devices(placements)
        self._dim = dim
        self._placements = placements

    @staticmethod
    def _verify_devices(placements):
        for dev in placements:
            if not isinstance(dev, PlacementSpec) and not is_valid_device(dev):
                raise ValueError(f'{dev} is not a valid device')

    @staticmethod
    def _verify_dim(dim):
        if not (isinstance(dim, int) or isinstance(dim, str)):
            raise ValueError(f'{dim} needs to either be an int or str')

    @property
    def dim(self) -> ShardingDim:
        """
        Retrieves the dimension to shard on.
        """
        return self._dim

    @property
    def placements(self) -> ShardPlacements:
        """
        Retrieves the shard placements.
        """
        return self._placements
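
To make the torch.chunk analogy concrete, a minimal sketch; the spec itself
performs no splitting, but the eventual sharding is meant to divide the tensor
the way torch.chunk would:

    import torch
    from torch.distributed._sharding_spec import ChunkShardingSpec

    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])

    # One shard per placement; dim 0 of size 8 splits into two (4, 4) chunks.
    tensor = torch.rand(8, 4)
    shards = torch.chunk(tensor, len(spec.placements), dim=spec.dim)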


@@ -20,7 +20,7 @@ import torch.distributed.rpc as rpc
from torch import Tensor, device, dtype, nn
from torch.distributed.nn.jit import instantiator
from torch.distributed.rpc.internal import _internal_rpc_pickler
from torch.distributed.rpc.utils import _parse_remote_device
from torch.distributed.utils import _parse_remote_device
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle

@@ -154,8 +154,12 @@ class _RemoteModule(nn.Module):
        Args:
            remote_device (str): Device on the destination worker where we'd like to place this module.
                The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
                E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
                The device can be a local device or a remote device specified by one of the following remote
                formats:
                    1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
                    2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").
                In addition, the device field can be optional and the default value is "cpu".
            module_cls (nn.Module): For example,
                >>> class MyModule(nn.Module):
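
A sketch of passing the new rank-based format, assuming the public
RemoteModule wrapper in torch.distributed.nn accepts the same remote_device
string as _RemoteModule (RPC must already be initialized across the workers):

    import torch.nn as nn
    from torch.distributed.nn import RemoteModule

    class MyModule(nn.Module):
        def forward(self, x):
            return x

    # Place the module on the first GPU of the process with rank 0.
    remote_module = RemoteModule("rank:0/cuda:0", MyModule)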


@@ -1,37 +0,0 @@
def _parse_remote_device(remote_device: str):
    r"""
    Parses the remote device.

    Args:
        remote_device (str): Device on the destination worker where we'd like to place this module.
            The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
            E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
            In addition, the device field can be optional and the default value is "cpu".

    Returns:
        A workername and a device.
    """
    fields = remote_device.split("/")
    if len(fields) == 2:
        [on, device] = fields
    elif len(fields) == 1:
        on = fields[0]
        device = "cpu"
    else:
        raise RuntimeError(
            "Could not parse remote_device: {}. The valid format is '<workername>/<device>'".format(
                remote_device
            )
        )

    # Since the workername in the input remote device won't be validated until the created remote module is executed,
    # only do some very basic sanity check on workername at the module creation time.
    # As currently there is no regex to describe the format of workername, just check whether the workername is empty.
    if not on:
        raise RuntimeError(
            "The workername in remote_device '{}' cannot be empty. The valid format is '<workername>/<device>'".format(
                remote_device
            )
        )

    return on, device


@@ -0,0 +1,56 @@
import torch

def _parse_remote_device(remote_device: str):
    r"""
    Parses the remote device.

    Args:
        remote_device (str): Device on the destination worker where we'd like to place this module.
            The format should be one of the following:

                1. "<workername>/<device>", where the device field can be parsed as torch.device type.
                   E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
                   In addition, the device field can be optional and the default value is "cpu".
                2. "rank:<rank>/<device>", where <rank> is the rank of the
                   process and device can be parsed as torch.device type.
                   E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"

    Returns:
        A workername/rank and a device.
    """
    PARSE_ERROR = (
        f"Could not parse remote_device: {remote_device}. The valid format is "
        "'<workername>/<device>' or 'rank:<rank>/<device>'"
    )

    fields = remote_device.split("/")
    if len(fields) == 2:
        [on, device] = fields
    elif len(fields) == 1:
        on = fields[0]
        device = "cpu"
    else:
        raise ValueError(PARSE_ERROR)

    # Since the workername in the input remote device won't be validated until the created remote module is executed,
    # only do some very basic sanity check on workername at the module creation time.
    # As currently there is no regex to describe the format of workername, just check whether the workername is empty.
    if not on:
        raise ValueError(PARSE_ERROR)

    # Validate the device.
    torch.device(device)

    # Check for rank based format.
    fields = on.split(':')
    if len(fields) == 2:
        # rank:<rank>/device format, extract rank
        if fields[0] == 'rank' and fields[1].isdigit():
            on = int(fields[1])  # type: ignore[assignment]
        else:
            raise ValueError(PARSE_ERROR)
    elif len(fields) > 2:
        raise ValueError(PARSE_ERROR)

    return on, device
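
A few example inputs and outputs for the rewritten parser, following the
formats documented above:

    _parse_remote_device("trainer0/cuda:0")  # -> ("trainer0", "cuda:0")
    _parse_remote_device("ps0")              # -> ("ps0", "cpu"); device defaults to "cpu"
    _parse_remote_device("rank:0/cpu")       # -> (0, "cpu"); rank is returned as an int
    _parse_remote_device("rank:foo")         # raises ValueError (rank must be an integer)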


@@ -554,7 +554,8 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
    def test_valid_device(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker_name = dist_utils.worker_name(dst_rank)

        for remote_module in self._create_remote_module_iter(
            "{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR]

@@ -565,6 +566,16 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
            self.assertEqual(device.type, "cuda")
            self.assertEqual(device.index, 0)

        # Test rank works as well.
        for remote_module in self._create_remote_module_iter(
            "rank:{}/cuda:0".format(dst_rank), modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            device = rpc.rpc_sync(
                dst_worker_name, remote_device, (remote_module.module_rref,)
            )
            self.assertEqual(device.type, "cuda")
            self.assertEqual(device.index, 0)

    @skip_if_lt_x_gpu(1)
    @dist_utils.dist_init
    def test_invalid_devices(self):

@@ -614,7 +625,7 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
        )
        with self.assertRaisesRegex(
            RuntimeError,
            ValueError,
            r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '<workername>/<device>'",
        ):
            list(

@@ -626,8 +637,8 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
        )
        with self.assertRaisesRegex(
            RuntimeError,
            r"The workername in remote_device '/' cannot be empty. The valid format is '<workername>/<device>'",
            ValueError,
            r"Could not parse remote_device: /. The valid format is '<workername>/<device>'",
        ):
            list(
                m.forward() for m in

@@ -638,8 +649,8 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
        )
        with self.assertRaisesRegex(
            RuntimeError,
            r"The workername in remote_device '/cuda:0' cannot be empty. The valid format is '<workername>/<device>'",
            ValueError,
            r"Could not parse remote_device: /cuda:0. The valid format is '<workername>/<device>'",
        ):
            list(
                m.forward() for m in