mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Introduce ChunkShardingSpec as a model sharding specification. (#55728)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55728 Full design: https://github.com/pytorch/pytorch/issues/55207 This PR introduces ChunkShardingSpec (SingleShardingSpec in the design). Used the name ChunkShardingSpec since it is very similar to `torch.chunk` in terms of how a Tensor is split up and feels more clear compared to SingleShardingSpec. ghstack-source-id: 129603318 Test Plan: waitforbuildbot Reviewed By: SciPioneer Differential Revision: D27694108 fbshipit-source-id: c8764abe6a4d5fc56d023fda29b74b5af2a73b49
This commit is contained in:
parent
c5a1f04367
commit
0d6fa1adc5
9 changed files with 264 additions and 46 deletions
56
test/distributed/_sharding_spec/test_sharding_spec.py
Normal file
56
test/distributed/_sharding_spec/test_sharding_spec.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
from torch.distributed._sharding_spec import (
|
||||
ChunkShardingSpec,
|
||||
DevicePlacementSpec
|
||||
)
|
||||
|
||||
class TestShardingSpec(TestCase):
    """Validation tests for DevicePlacementSpec and ChunkShardingSpec."""

    def test_device_placement(self):
        # All of these should be accepted as valid local/remote devices.
        for valid_device in [
            "cuda:0",
            0,
            torch.device("cuda:0"),
            "rank:0/cuda:0",
            "rank:0/cpu",
            "rank:0",
        ]:
            DevicePlacementSpec(valid_device)

        # All of these should be rejected as devices.
        for invalid_device in ["cuda:foo", "foo:0", "rank:0/cuda:foo", "rank:0/cpu2"]:
            with self.assertRaisesRegex(ValueError, "not a valid device"):
                DevicePlacementSpec(invalid_device)

    def test_chunked_sharding_spec(self):
        # Test valid specs.
        for dim, placements in [
            (0, [0, 1]),
            # Named dimension.
            ("N", ["cuda:0", "cuda:1"]),
            (0, [torch.device("cuda:0"), torch.device("cuda:1")]),
            (-1, ["cuda:0", "cuda:1"]),
            (0, ["rank:0/cuda:0", "rank:0/cuda:1"]),
            (0, ["rank:0", "rank:1"]),
            (0, ["rank:0/cpu", "rank:1/cpu"]),
        ]:
            ChunkShardingSpec(dim, placements)

        # Test invalid sharding dims.
        for invalid_dim in [None, {}]:
            with self.assertRaisesRegex(ValueError, "int or str"):
                ChunkShardingSpec(invalid_dim, ["cuda:0", "cuda:1"])

        # Test invalid devices in the placement list.
        for invalid_device in [
            "random:0",
            "cuda:foo",
            "rank:foo",
            "rank:0/foo",
            "rank:0/random:0",
            "rank:0/cuda:foo",
        ]:
            with self.assertRaisesRegex(ValueError, "not a valid device"):
                ChunkShardingSpec(0, [invalid_device, "cuda:1"])
|
||||
|
|
@ -161,6 +161,7 @@ TESTS = [
|
|||
'distributed/elastic/utils/util_test',
|
||||
'distributed/elastic/utils/distributed_test',
|
||||
'distributed/elastic/multiprocessing/api_test',
|
||||
'distributed/_sharding_spec/test_sharding_spec',
|
||||
]
|
||||
|
||||
# Tests need to be run with pytest.
|
||||
|
|
|
|||
5
torch/distributed/_sharding_spec/__init__.py
Normal file
5
torch/distributed/_sharding_spec/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
from .api import (
|
||||
PlacementSpec,
|
||||
DevicePlacementSpec,
|
||||
ChunkShardingSpec
|
||||
)
|
||||
22
torch/distributed/_sharding_spec/_internals.py
Normal file
22
torch/distributed/_sharding_spec/_internals.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
import torch
|
||||
from torch.distributed.utils import _parse_remote_device
|
||||
|
||||
def is_valid_device(device):
    """
    Checks if this is a valid local/remote device.
    """
    def _parses_with(parse):
        # Any failure (bad format, bad device string, ...) simply means
        # the candidate parser does not accept this device.
        try:
            parse(device)
            return True
        except Exception:
            return False

    # Accept anything torch.device understands (local device), otherwise
    # anything that parses as a remote device. The lambda defers the name
    # lookup of _parse_remote_device until inside the guarded call.
    return _parses_with(torch.device) or _parses_with(lambda d: _parse_remote_device(d))
|
||||
100
torch/distributed/_sharding_spec/api.py
Normal file
100
torch/distributed/_sharding_spec/api.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
from abc import ABC
|
||||
import torch
|
||||
from typing import List, Union
|
||||
|
||||
from ._internals import is_valid_device
|
||||
|
||||
Device = Union[torch.device, int, str]
|
||||
|
||||
class PlacementSpec(ABC):
    """
    Base class representing the placement of an entity. Subclasses of this
    class can be used to specify customized placements which might not be
    covered by existing APIs.
    """
    # Intentionally empty: this class only serves as a common base type.
|
||||
|
||||
class DevicePlacementSpec(PlacementSpec):
    """
    Associates placement of an entity with a single device. The device can be a
    local device or a remote device specified by one of the following remote
    formats:

        1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
        2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").

    Args:
        device(str, :class:`torch.device`): The device to place the entity on.

    Raises:
        ValueError: If ``device`` is neither a valid local nor remote device.
    """
    def __init__(self, device: Device):
        # Zero-arg super() is the Python 3 idiom (file already uses f-strings).
        super().__init__()
        # Reject the device eagerly so misconfiguration fails at spec
        # construction time rather than when the spec is consumed.
        if not is_valid_device(device):
            raise ValueError(f'{device} is not a valid device')
        self._device = device

    @property
    def device(self) -> Device:
        """
        Retrieves the device for placement.
        """
        return self._device
|
||||
|
||||
class ChunkShardingSpec(PlacementSpec):
    """
    A type of PlacementSpec that defines the placement as being sharded
    across multiple devices. In particular, it represents sharding a Tensor
    along a single dimension into equal chunks (similar to :meth:`torch.chunk`).

    Args:
        dim (int or str):
            The dimension to shard on, could be an integer representing the
            dimension or a string in case of named tensors where dimensions are
            named.
        placements(List[Device] or List[PlacementSpec]):
            Specifies the placement of each shard of the Tensor. The size of
            the list represents the number of shards to be created. This
            parameter can be a list of devices
            (ex: ["rank:0/cuda:0", "rank:1/cuda:1"]) or a list of custom
            placement specs.

            The device can be a local device or a remote device specified by one
            of the following remote formats:

            1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
            2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").

    Raises:
        ValueError: If ``dim`` is not an int/str or any entry of
            ``placements`` is neither a PlacementSpec nor a valid device.
    """

    ShardPlacements = List[Union[Device, PlacementSpec]]
    ShardingDim = Union[int, str]

    def __init__(self, dim: ShardingDim, placements: ShardPlacements):
        # Zero-arg super() is the Python 3 idiom (file already uses f-strings).
        super().__init__()
        # Validate eagerly so a bad spec fails at construction time.
        self._verify_dim(dim)
        self._verify_devices(placements)
        self._dim = dim
        self._placements = placements

    @staticmethod
    def _verify_devices(placements):
        # Each entry is either a custom PlacementSpec or a (local or remote)
        # device identifier accepted by is_valid_device.
        for dev in placements:
            if not isinstance(dev, PlacementSpec) and not is_valid_device(dev):
                raise ValueError(f'{dev} is not a valid device')

    @staticmethod
    def _verify_dim(dim):
        # An int indexes a regular tensor dimension; a str names a dimension
        # of a named tensor.
        if not isinstance(dim, (int, str)):
            raise ValueError(f'{dim} needs to either be an int or str')

    @property
    def dim(self) -> ShardingDim:
        """
        Retrieves the dimension to shard on.
        """
        return self._dim

    @property
    def placements(self) -> ShardPlacements:
        """
        Retrieves the shard placements.
        """
        return self._placements
|
||||
|
|
@ -20,7 +20,7 @@ import torch.distributed.rpc as rpc
|
|||
from torch import Tensor, device, dtype, nn
|
||||
from torch.distributed.nn.jit import instantiator
|
||||
from torch.distributed.rpc.internal import _internal_rpc_pickler
|
||||
from torch.distributed.rpc.utils import _parse_remote_device
|
||||
from torch.distributed.utils import _parse_remote_device
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
|
@ -154,8 +154,12 @@ class _RemoteModule(nn.Module):
|
|||
|
||||
Args:
|
||||
remote_device (str): Device on the destination worker where we'd like to place this module.
|
||||
The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
|
||||
E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
|
||||
The device can be a local device or a remote device specified by one of the following remote
|
||||
formats:
|
||||
|
||||
1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
|
||||
2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").
|
||||
|
||||
In addition, the device field can be optional and the default value is "cpu".
|
||||
module_cls (nn.Module): For example,
|
||||
>>> class MyModule(nn.Module):
|
||||
|
|
|
|||
|
|
@ -1,37 +0,0 @@
|
|||
def _parse_remote_device(remote_device: str):
|
||||
r"""
|
||||
Parses the remote device.
|
||||
|
||||
Args:
|
||||
remote_device (str): Device on the destination worker where we'd like to place this module.
|
||||
The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
|
||||
E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
|
||||
In addition, the device field can be optional and the default value is "cpu".
|
||||
|
||||
Returns:
|
||||
A workername and a device.
|
||||
"""
|
||||
fields = remote_device.split("/")
|
||||
if len(fields) == 2:
|
||||
[on, device] = fields
|
||||
elif len(fields) == 1:
|
||||
on = fields[0]
|
||||
device = "cpu"
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Could not parse remote_device: {}. The valid format is '<workername>/<device>'".format(
|
||||
remote_device
|
||||
)
|
||||
)
|
||||
|
||||
# Since the workername in the input remote device won't be validated until the created remote module is executed,
|
||||
# only do some very basic sanity check on workername at the module creation time.
|
||||
# As currently there is no regex to describe the format of workername, just check whether the workername is empty.
|
||||
if not on:
|
||||
raise RuntimeError(
|
||||
"The workername in remote_device '{}' cannot be empty. The valid format is '<workername>/<device>'".format(
|
||||
remote_device
|
||||
)
|
||||
)
|
||||
|
||||
return on, device
|
||||
56
torch/distributed/utils.py
Normal file
56
torch/distributed/utils.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
import torch
|
||||
|
||||
def _parse_remote_device(remote_device: str):
|
||||
r"""
|
||||
Parses the remote device.
|
||||
|
||||
Args:
|
||||
remote_device (str): Device on the destination worker where we'd like to place this module.
|
||||
The format should be one of the following:
|
||||
|
||||
1. "<workername>/<device>", where the device field can be parsed as torch.device type.
|
||||
E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
|
||||
In addition, the device field can be optional and the default value is "cpu".
|
||||
2. "rank:<rank>/<device>", where <rank> is the rank of the
|
||||
process and device can be parsed as torch.device type.
|
||||
E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
|
||||
|
||||
Returns:
|
||||
A workername/rank and a device.
|
||||
"""
|
||||
|
||||
PARSE_ERROR = (
|
||||
f"Could not parse remote_device: {remote_device}. The valid format is "
|
||||
"'<workername>/<device>' or 'rank:<rank>/<device>'"
|
||||
)
|
||||
|
||||
fields = remote_device.split("/")
|
||||
if len(fields) == 2:
|
||||
[on, device] = fields
|
||||
elif len(fields) == 1:
|
||||
on = fields[0]
|
||||
device = "cpu"
|
||||
else:
|
||||
raise ValueError(PARSE_ERROR)
|
||||
|
||||
# Since the workername in the input remote device won't be validated until the created remote module is executed,
|
||||
# only do some very basic sanity check on workername at the module creation time.
|
||||
# As currently there is no regex to describe the format of workername, just check whether the workername is empty.
|
||||
if not on:
|
||||
raise ValueError(PARSE_ERROR)
|
||||
|
||||
# Validate the device.
|
||||
torch.device(device)
|
||||
|
||||
# Check for rank based format
|
||||
fields = on.split(':')
|
||||
if len(fields) == 2:
|
||||
# rank:<rank>/device format, extract rank
|
||||
if fields[0] == 'rank' and fields[1].isdigit():
|
||||
on = int(fields[1]) # type: ignore[assignment]
|
||||
else:
|
||||
raise ValueError(PARSE_ERROR)
|
||||
elif len(fields) > 2:
|
||||
raise ValueError(PARSE_ERROR)
|
||||
|
||||
return on, device
|
||||
|
|
@ -554,7 +554,8 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
|
|||
def test_valid_device(self):
|
||||
if self.rank != 0:
|
||||
return
|
||||
dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
|
||||
dst_rank = (self.rank + 1) % self.world_size
|
||||
dst_worker_name = dist_utils.worker_name(dst_rank)
|
||||
|
||||
for remote_module in self._create_remote_module_iter(
|
||||
"{}/cuda:0".format(dst_worker_name), modes=[ModuleCreationMode.MODULE_CTOR]
|
||||
|
|
@ -565,6 +566,16 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
|
|||
self.assertEqual(device.type, "cuda")
|
||||
self.assertEqual(device.index, 0)
|
||||
|
||||
# Test rank works as well.
|
||||
for remote_module in self._create_remote_module_iter(
|
||||
"rank:{}/cuda:0".format(dst_rank), modes=[ModuleCreationMode.MODULE_CTOR]
|
||||
):
|
||||
device = rpc.rpc_sync(
|
||||
dst_worker_name, remote_device, (remote_module.module_rref,)
|
||||
)
|
||||
self.assertEqual(device.type, "cuda")
|
||||
self.assertEqual(device.index, 0)
|
||||
|
||||
@skip_if_lt_x_gpu(1)
|
||||
@dist_utils.dist_init
|
||||
def test_invalid_devices(self):
|
||||
|
|
@ -614,7 +625,7 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
|
|||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
ValueError,
|
||||
r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '<workername>/<device>'",
|
||||
):
|
||||
list(
|
||||
|
|
@ -626,8 +637,8 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
|
|||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"The workername in remote_device '/' cannot be empty. The valid format is '<workername>/<device>'",
|
||||
ValueError,
|
||||
r"Could not parse remote_device: /. The valid format is '<workername>/<device>'",
|
||||
):
|
||||
list(
|
||||
m.forward() for m in
|
||||
|
|
@ -638,8 +649,8 @@ class CudaRemoteModuleTest(CommonRemoteModuleTest):
|
|||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"The workername in remote_device '/cuda:0' cannot be empty. The valid format is '<workername>/<device>'",
|
||||
ValueError,
|
||||
r"Could not parse remote_device: /cuda:0. The valid format is '<workername>/<device>'",
|
||||
):
|
||||
list(
|
||||
m.forward() for m in
|
||||
|
|
|
|||
Loading…
Reference in a new issue