Additional ops for ShardedTensor, ReplicatedTensor and PartialTensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76477

Adding the following ops:

1) softmax for ShardedTensor
2) getitem and unsqueeze for ReplicatedTensor
3) transpose and cat for PartialTensor

Differential Revision: [D35979510](https://our.internmc.facebook.com/intern/diff/D35979510/)

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
This commit is contained in:
pritam 2022-05-06 09:04:08 -07:00 committed by PyTorch MergeBot
parent c2f362d36c
commit 9e52b50e34
10 changed files with 287 additions and 36 deletions

View file

@ -0,0 +1,57 @@
# Owner(s): ["oncall: distributed"]
import sys
import torch
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
run_tests,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
TEST_GPU_NUM,
ShardedTensorTestBase,
with_comms,
)
from torch.testing._internal.common_distributed import (
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard import _shard_tensor
# Bail out early under dev-asan builds: torch + multiprocessing spawn is
# known to misbehave there, so the whole file is skipped (exit 0 = "skipped",
# not "failed").
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
class TestShardedSoftmax(ShardedTensorTestBase):
    """Checks that softmax on a ShardedTensor matches softmax on the
    equivalent unsharded local tensor, both on and off the sharding dim."""

    def _test_sharded_softmax(self, dim):
        """Shard a 10x10 tensor on dim 0 and compare sharded vs. local softmax
        along ``dim``."""
        torch.manual_seed(0)
        local_tensor = torch.rand(10, 10, device=self.rank)
        local_softmax = torch.nn.functional.softmax(local_tensor, dim)
        # One chunk per rank, sharded along dim 0.
        spec = ChunkShardingSpec(
            dim=0,
            placements=[f'rank:{idx}/cuda:{idx}' for idx in range(self.world_size)],
        )
        st = _shard_tensor(local_tensor, spec)
        sharded_softmax = torch.nn.functional.softmax(st, dim)
        # Each rank holds exactly one dim-0 chunk of the full result.
        self.assertEqual(local_softmax.chunk(self.world_size)[self.rank], sharded_softmax.local_tensor())

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_softmax_basic(self):
        # BUG FIX: the dims here were swapped with the test below. The spec
        # shards dim 0, so the "basic" (non-sharding-dim) case is dim 1/-1.
        self._test_sharded_softmax(1)
        self._test_sharded_softmax(-1)

    @with_comms(init_rpc=False)
    @skip_if_lt_x_gpu(TEST_GPU_NUM)
    @requires_nccl()
    def test_sharded_softmax_on_sharding_dim(self):
        # 0 and -2 both alias the sharding dim of the 2-D input.
        self._test_sharded_softmax(0)
        self._test_sharded_softmax(-2)


if __name__ == "__main__":
    run_tests()

View file

@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.distributed._shard.sharded_tensor import (
ShardedTensorTestBase,
with_comms,
TEST_GPU_NUM
)
from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
_chunk_sharding_specs_list_for_test,
@ -70,7 +71,7 @@ class TestPartialTensorReshard(ShardedTensorTestBase):
self.assertEqual(local_shards[0].tensor, local_result_compare)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_partial_tensor_reshard(self):
specs = _chunk_sharding_specs_list_for_test([0], seed=7)
@ -81,7 +82,7 @@ class TestPartialTensorReshard(ShardedTensorTestBase):
self._run_partial_tensor_n_reshard(spec, [17, 21], 2, dist.ReduceOp.MAX)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_partial_tensor_reshard_errors(self):
enumerable_sharding_spec = EnumerableShardingSpec(
@ -119,6 +120,54 @@ class TestPartialTensorReshard(ShardedTensorTestBase):
spec, [12, 22], 4, dist.ReduceOp.MAX, dtype=torch.cfloat
)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_transpose(self):
    """Transposing a 2-D _PartialTensor swaps its two dimensions."""
    original = _PartialTensor(torch.rand(5, 10))
    transposed = original.transpose(0, 1)
    self.assertEqual(transposed.size(), torch.Size((10, 5)))
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_cat(self):
    """torch.cat over _PartialTensors yields the same shape as cat over the
    corresponding local tensors, with and without the ``dim`` kwarg."""
    def check(shapes, **cat_kwargs):
        local_tensors = [torch.rand(*shape) for shape in shapes]
        partial_tensors = [_PartialTensor(t) for t in local_tensors]
        local_concat = torch.cat(local_tensors, **cat_kwargs)
        partial_concat = torch.cat(partial_tensors, **cat_kwargs)
        self.assertEqual(local_concat.size(), partial_concat.size())

    # Default dim (0): first dims differ, second dims agree.
    check([(5, 10), (3, 10), (4, 10)])
    # Explicit dim=1: first dims agree, second dims differ.
    check([(5, 10), (5, 12), (5, 11)], dim=1)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_cat_errors(self):
    """torch.cat on _PartialTensor rejects mixed inputs, mismatched
    reduce_ops, and the ``out=`` kwarg."""
    mixed_inputs = [_PartialTensor(torch.rand(10)), torch.rand(10)]
    with self.assertRaisesRegex(
        RuntimeError, 'All inputs need to be an instance of _PartialTensor'
    ):
        torch.cat(mixed_inputs)

    mismatched_ops = [
        _PartialTensor(torch.rand(10)),
        _PartialTensor(torch.rand(10), reduce_op=dist.ReduceOp.MAX),
    ]
    with self.assertRaisesRegex(
        RuntimeError, 'reduce_ops need to be the same'
    ):
        torch.cat(mismatched_ops)

    plain_pair = [_PartialTensor(torch.rand(10)), _PartialTensor(torch.rand(10))]
    with self.assertRaisesRegex(
        RuntimeError, '"out" kwarg is not supported'
    ):
        torch.cat(plain_pair, out=torch.rand(10))
if __name__ == "__main__":
run_tests()

View file

@ -22,17 +22,17 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
gen_binary_op_func
)
from torch.testing._internal.distributed._shard.sharded_tensor import TEST_GPU_NUM
class TestReplicatedTensor(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_replicated_tensor_basics(self):
local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
replica_tensor = ReplicatedTensor(local_tensor)
print(replica_tensor.process_group)
# validate it's a replicated tensor by checking values on all rank
validated = replica_tensor.validate()
self.assertEqual(validated, True)
@ -49,7 +49,7 @@ class TestReplicatedTensor(ShardedTensorTestBase):
replica_tensor.validate()
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_replicated_tensor_inter_op_replicated_tensor(self):
local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}")
@ -69,7 +69,7 @@ class TestReplicatedTensor(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_replicated_tensor_inter_op_tensor(self):
local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
@ -84,7 +84,7 @@ class TestReplicatedTensor(ShardedTensorTestBase):
self.assertEqual(new_tensor, local_tensor + local_rand_tensor)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_replicated_tensor_inter_op_sharded_tensor(self):
torch.manual_seed(self.rank)
@ -132,7 +132,7 @@ class TestReplicatedTensor(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_replicated_tensor_implicit_broadcasting(self):
# use same seed
@ -174,7 +174,7 @@ class TestReplicatedTensor(ShardedTensorTestBase):
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_replicated_tensor_inter_op_sharded_tensor_errors(self):
local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
@ -201,7 +201,7 @@ class TestReplicatedTensor(ShardedTensorTestBase):
st1 % replica_tensor
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_with_ddp(self):
# Test Replicated params for DDP
@ -306,3 +306,31 @@ class TestReplicatedTensor(ShardedTensorTestBase):
buffer.seek(0)
obj = torch.load(buffer)
self.assertEqual(expected_state_dict, obj.state_dict())
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_unsqueeze(self):
    """Both Tensor.unsqueeze and torch.unsqueeze on a ReplicatedTensor
    return a ReplicatedTensor matching the local result."""
    base = torch.rand(3, 3, device=self.rank)
    replica = ReplicatedTensor(base)
    via_method = replica.unsqueeze(0)
    via_function = torch.unsqueeze(replica, 0)
    expected = base.unsqueeze(0)
    self.assertIsInstance(via_method, ReplicatedTensor)
    self.assertIsInstance(via_function, ReplicatedTensor)
    self.assertEqual(expected, via_method)
    self.assertEqual(via_function, via_method)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(TEST_GPU_NUM)
@requires_nccl()
def test_getitem(self):
    """Indexing a ReplicatedTensor yields a ReplicatedTensor view whose
    values match indexing the underlying local tensor."""
    base = torch.rand(3, 3, device=self.rank)
    replica = ReplicatedTensor(base)
    replica_row = replica[0]
    self.assertIsInstance(replica_row, ReplicatedTensor)
    self.assertEqual(base[0], replica_row)

View file

@ -215,6 +215,7 @@ WINDOWS_BLOCKLIST = [
"distributed/_shard/sharded_tensor/ops/test_linear",
"distributed/_shard/sharded_tensor/ops/test_math_ops",
"distributed/_shard/sharded_tensor/ops/test_matrix_ops",
"distributed/_shard/sharded_tensor/ops/test_softmax",
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_optim/test_sharded_optim",
"distributed/_shard/test_partial_tensor",
@ -240,6 +241,7 @@ ROCM_BLOCKLIST = [
"distributed/_shard/sharded_tensor/ops/test_linear",
"distributed/_shard/sharded_tensor/ops/test_math_ops",
"distributed/_shard/sharded_tensor/ops/test_matrix_ops",
"distributed/_shard/sharded_tensor/ops/test_softmax",
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_optim/test_sharded_optim",
"distributed/_shard/test_partial_tensor",

View file

@ -1,3 +1,6 @@
import functools
from typing import Callable, Dict
import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
@ -6,7 +9,36 @@ from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed.nn.functional import (
reduce_scatter,
)
from torch.overrides import handle_torch_function
# Registry mapping a torch API callable to its _PartialTensor implementation.
_PARTIAL_TENSOR_OPS: Dict[Callable, Callable] = {}


def _register_partial_tensor_op(op, func):
    """Record ``func`` as the _PartialTensor handler dispatched for ``op``.

    ``func`` must accept exactly the ``(types, args, kwargs)`` signature used
    by ``__torch_function__`` dispatch; anything else raises TypeError.
    """
    from inspect import signature

    func_signature = signature(func)
    if len(func_signature.parameters) != 3:
        raise TypeError(
            f'Partial tensor op function expects signature: '
            f'(types, args, kwargs), but received '
            f'signature: {func_signature}')
    _PARTIAL_TENSOR_OPS[op] = func
def _custom_partial_tensor_op(func):
    """
    Decorator that registers a custom _PartialTensor implementation for a
    torch function.

    Args:
        func(Callable): Torch function for which we want to provide a
            PartialTensor implementation (ex: torch.nn.functional.linear)
    """
    def register_and_wrap(op_impl):
        # Registration happens at decoration time, so dispatch works even if
        # the returned wrapper is never called directly.
        _register_partial_tensor_op(func, op_impl)

        @functools.wraps(op_impl)
        def delegate(*args, **kwargs):
            return op_impl(*args, **kwargs)

        return delegate

    return register_and_wrap
class _PartialTensor(object):
"""
@ -175,3 +207,61 @@ class _PartialTensor(object):
sharded_tensor_size,
process_group=self.process_group,
)
def size(self):
    """Return the size of the local shard backing this _PartialTensor."""
    local_shard_size = self.local_shard.size()
    return local_shard_size
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    """Dispatch torch API calls on _PartialTensor to the registered custom
    op, or fail loudly for unsupported functions."""
    handler = _PARTIAL_TENSOR_OPS.get(func)
    if handler is not None:
        return handler(types, args, kwargs)
    raise RuntimeError(
        f"torch function '{func.__name__}', with args: {args} and "
        f"kwargs: {kwargs} not supported for PartialTensor!")
def transpose(self, dim0, dim1):
    """Swap ``dim0`` and ``dim1`` by routing through ``handle_torch_function``,
    which lands in ``__torch_function__`` and the registered transpose op.

    NOTE(review): the relevant-args tuple passed here includes the int dims,
    not just ``self``; ints have no ``__torch_function__`` so this appears
    harmless, but ``(self,)`` is the conventional form — confirm.
    """
    return handle_torch_function(torch.Tensor.transpose, (self, dim0, dim1), self, dim0, dim1)
def _transpose_impl(types, args=(), kwargs=None):
    """Shared transpose implementation: transpose the local shard and wrap it
    in a new _PartialTensor carrying the same process group and reduce op."""
    partial_input, first_dim, second_dim = args[0], args[1], args[2]
    swapped_shard = torch.transpose(partial_input.local_shard, first_dim, second_dim)
    return _PartialTensor(
        swapped_shard,
        partial_input.process_group,
        partial_input.reduce_op,
    )
# Register both the method form (Tensor.transpose) and the functional form
# (torch.transpose) so either spelling dispatches to the shared
# implementation in _transpose_impl.
@_custom_partial_tensor_op(torch.Tensor.transpose)
def partial_transpose(types, args=(), kwargs=None):
    # Delegates to the shared transpose implementation.
    return _transpose_impl(types, args, kwargs)


@_custom_partial_tensor_op(torch.transpose)
def partial_torch_transpose(types, args=(), kwargs=None):
    # Delegates to the shared transpose implementation.
    return _transpose_impl(types, args, kwargs)
@_custom_partial_tensor_op(torch.cat)
def partial_cat(types, args=(), kwargs=None):
    """
    torch.cat implementation for _PartialTensor.

    Concatenates the local shards along ``dim`` (default 0) and wraps the
    result in a new _PartialTensor.

    Raises:
        RuntimeError: if the input list is empty, any input is not a
            _PartialTensor, reduce_ops differ across inputs, or the
            unsupported ``out=`` kwarg is passed.
    """
    input_list = args[0]
    if len(input_list) == 0:
        raise RuntimeError('Empty list of tensors to torch.cat!')
    local_shards = []
    for idx, input in enumerate(input_list):
        if not isinstance(input, _PartialTensor):
            raise RuntimeError('All inputs need to be an instance of _PartialTensor')
        if idx == 0:
            reduce_op = input.reduce_op
        elif reduce_op != input.reduce_op:
            # BUG FIX: this message was a plain string, so the literal text
            # "{reduce_op}" was printed instead of the actual ops.
            raise RuntimeError(
                f'All _PartialTensor reduce_ops need to be the same, '
                f'found: {reduce_op} and {input.reduce_op}')
        local_shards.append(input.local_shard)
    if kwargs is None:
        dim = 0
    else:
        if 'out' in kwargs:
            raise RuntimeError('"out" kwarg is not supported!')
        dim = kwargs.get('dim', 0)
    # Preserves prior behavior: the result takes the last input's process
    # group (groups are not checked for equality, only reduce_ops are).
    last_input = input_list[-1]
    return _PartialTensor(
        torch.cat(local_shards, dim),
        last_input.process_group,
        last_input.reduce_op,
    )

View file

@ -5,6 +5,14 @@ from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed import distributed_c10d
from torch.overrides import get_default_nowrap_functions
# Ops that may still produce a ReplicatedTensor when some operands are plain
# Python values (e.g. the int dim for unsqueeze, or an index for getitem)
# rather than tensors.
_REPLICATED_WITH_NON_TENSOR_ALLOWLIST = [
    # List of ops where if parameters are a combination of ReplicatedTensors
    # and non-tensors, we can still return a ReplicatedTensor as the result.
    torch.unsqueeze,
    torch.Tensor.unsqueeze,
    torch.Tensor.__getitem__,
]
class ReplicatedTensor(torch.Tensor):
"""
ReplicatedTensor represents a tensor which is replicated across the `world_size` and
@ -60,12 +68,13 @@ class ReplicatedTensor(torch.Tensor):
# are all replicated tensor operands, we have to do this to ensure we do not
# converting results back to ReplicatedTensor if not all operands are replicated.
all_replicated = True
replicated_with_non_tensor = True
replicated_pg = None
def dispatch_arg(arg):
# This function returns a tuple, first element represents whether the op been
# executed, the second element represents the result of the execution
nonlocal replicated_pg, all_replicated
nonlocal replicated_pg, all_replicated, replicated_with_non_tensor
if isinstance(arg, ShardedTensor):
# redispatch to ShardedTensor
# TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
@ -78,6 +87,9 @@ class ReplicatedTensor(torch.Tensor):
f"ReplicatedTensor operands must be in the same process group "
f"in torch function '{func.__name__}', but found at least two "
f"ReplicatedTensor operands in different process groups! ")
elif isinstance(arg, torch.Tensor):
replicated_with_non_tensor = False
all_replicated = False
else:
all_replicated = False
@ -101,7 +113,12 @@ class ReplicatedTensor(torch.Tensor):
rs = func(*args, **kwargs)
if func in get_default_nowrap_functions():
return rs
if all_replicated and isinstance(rs, torch.Tensor) and not isinstance(rs, cls):
result_not_replicated = isinstance(rs, torch.Tensor) and not isinstance(rs, ReplicatedTensor)
should_convert_to_replicated = all_replicated or (
replicated_with_non_tensor and func in _REPLICATED_WITH_NON_TENSOR_ALLOWLIST
)
if result_not_replicated and should_convert_to_replicated:
# if all operands are ReplicatedTensors and does not get dispatched to ShardedTensor
# __torch_function__, result is a torch.Tensor, then we convert and return a
# ReplicatedTensor according to our inter-op rule

View file

@ -9,5 +9,6 @@ from .init import kaiming_uniform_, normal_, uniform_, constant_
from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.linear import sharded_linear
from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import sharded_embedding
from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import sharded_embedding_bag
from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.softmax import sharded_softmax
import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.math_ops
import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.matrix_ops

View file

@ -7,3 +7,4 @@ from ._common import (
_register_sharded_op_on_local_shards(torch.nn.functional.gelu)
_register_sharded_op_on_local_shards(torch.nn.functional.relu)
_register_sharded_op_on_local_shards(torch.nn.functional.dropout)
_register_sharded_op_on_local_shards(torch.Tensor.tanh)

View file

@ -91,30 +91,6 @@ _register_sharded_op_on_local_tensor(
)
def sharded_softmax_check(*args, **kwargs):
    """
    Validate a ``torch.Tensor.softmax`` call on a ShardedTensor: softmax
    along the sharding dim is not supported.

    Args: same as ``torch.Tensor.softmax``.
    Return: None
    Raises: NotImplementedError when ``dim`` equals the sharding dim.
    """
    sharded_input = args[0]
    softmax_dim = kwargs.get("dim")
    # Historically softmax falls back to dim 1 when no dim is given.
    if softmax_dim is None:
        softmax_dim = 1
    # NOTE(review): negative dims are not normalized here, so e.g. dim=-1 on
    # a tensor sharded along its last dim would slip past this check — confirm.
    if softmax_dim == sharded_input.sharding_spec().dim:
        raise NotImplementedError(
            "Only support performing softmax on non-sharding dim now."
        )
# Register softmax for ShardedTensor via the default local-tensor path,
# guarded by sharded_softmax_check (rejects softmax along the sharding dim).
_register_sharded_op_on_local_tensor(
    torch.nn.functional.softmax,
    extra_check=sharded_softmax_check,
)
def sharded_masked_fill_check(*args, **kwargs):
"""
Perform extra checks for the ``torch.Tensor.masked_fill`` op.

View file

@ -0,0 +1,30 @@
import torch
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
from ._common import (
_register_sharded_op_on_local_tensor,
)
@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.softmax)
def sharded_softmax(types, args=(), kwargs=None):
    """
    ``torch.nn.functional.softmax`` for a ShardedTensor with
    ChunkShardingSpec.

    When ``dim`` is the sharding dim, each rank computes a local exp and
    exp-sum, all-reduces the sums across the process group, then divides.
    Otherwise softmax is applied purely locally along ``dim``.
    Returns a new ShardedTensor with the same spec and global size.
    """
    input = args[0]
    pg = input._process_group
    dim = kwargs['dim']
    sharding_dim = input.sharding_spec().dim
    ndims = len(input.size())
    # Match dim against the sharding dim under both positive and negative
    # indexing conventions.
    if dim == sharding_dim or dim + ndims == sharding_dim or sharding_dim + ndims == dim:
        # NOTE: no max-subtraction is done before exp, so this is less
        # numerically stable than torch.softmax for large-magnitude inputs.
        exp = torch.exp(input.local_tensor())
        exp_sum = exp.sum(dim=dim).unsqueeze(dim=dim)
        exp_sum = torch.distributed.nn.functional.all_reduce(exp_sum, group=pg)
        smax = torch.div(exp, exp_sum)
    else:
        # BUG FIX: `dim` was previously dropped here, so softmax silently ran
        # over the default dim. The 2-D tests masked this because the default
        # happens to coincide with dim 1 / -1.
        smax = torch.nn.functional.softmax(input.local_tensor(), dim)
    return ShardedTensor._init_from_local_tensor(smax, input.sharding_spec(), input.size(), process_group=pg)
# Register sharded_softmax as the full custom implementation for
# torch.nn.functional.softmax on ShardedTensor (customized_func replaces the
# default local-tensor path entirely).
_register_sharded_op_on_local_tensor(
    torch.nn.functional.softmax,
    customized_func=sharded_softmax,
)