Additional ops for ShardedTensor, ReplicatedTensor and PartialTensor.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76477

Adding the following ops:
1) softmax for ShardedTensor
2) getitem and unsqueeze for ReplicatedTensor
3) transpose and cat for PartialTensor

Differential Revision: [D35979510](https://our.internmc.facebook.com/intern/diff/D35979510/)
Approved by: https://github.com/fduwjj, https://github.com/wanchaol
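For context, a hedged usage sketch of the three op groups this change enables (illustrative, not taken from the diff; it assumes a 2-rank NCCL process group is already initialized and `rank` is the local CUDA device):

import torch
import torch.distributed as dist
from torch.distributed._shard import _shard_tensor
from torch.distributed._shard.partial_tensor import _PartialTensor
from torch.distributed._shard.replicated_tensor import ReplicatedTensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

rank = dist.get_rank()
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])

# 1) softmax on a ShardedTensor, now including along the sharding dim
st = _shard_tensor(torch.rand(4, 4, device=rank), spec)
st_softmax = torch.nn.functional.softmax(st, dim=0)

# 2) getitem / unsqueeze on a ReplicatedTensor keep the result replicated
rt = ReplicatedTensor(torch.rand(3, 3, device=rank))
row = rt[0]                 # still a ReplicatedTensor
expanded = rt.unsqueeze(0)  # still a ReplicatedTensor

# 3) transpose / cat on _PartialTensor operate on the local shard
pt = _PartialTensor(torch.rand(5, 10, device=rank))
pt_t = torch.transpose(pt, 0, 1)
pt_cat = torch.cat([pt_t, pt_t])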
This commit is contained in:
parent c2f362d36c
commit 9e52b50e34
10 changed files with 287 additions and 36 deletions
test/distributed/_shard/sharded_tensor/ops/test_softmax.py (new file, 57 lines):
@@ -0,0 +1,57 @@
+# Owner(s): ["oncall: distributed"]
+
+import sys
+
+import torch
+from torch.testing._internal.common_utils import (
+    TEST_WITH_DEV_DBG_ASAN,
+    run_tests,
+)
+from torch.testing._internal.distributed._shard.sharded_tensor import (
+    TEST_GPU_NUM,
+    ShardedTensorTestBase,
+    with_comms,
+)
+from torch.testing._internal.common_distributed import (
+    requires_nccl,
+    skip_if_lt_x_gpu,
+)
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard import _shard_tensor
+
+if TEST_WITH_DEV_DBG_ASAN:
+    print(
+        "Skip dev-asan as torch + multiprocessing spawn have known issues",
+        file=sys.stderr,
+    )
+    sys.exit(0)
+
+
+class TestShardedSoftmax(ShardedTensorTestBase):
+
+    def _test_sharded_softmax(self, dim):
+        torch.manual_seed(0)
+        local_tensor = torch.rand(10, 10, device=self.rank)
+        local_softmax = torch.nn.functional.softmax(local_tensor, dim)
+
+        spec = ChunkShardingSpec(dim=0, placements=[f'rank:{idx}/cuda:{idx}' for idx in range(self.world_size)])
+        st = _shard_tensor(local_tensor, spec)
+        sharded_softmax = torch.nn.functional.softmax(st, dim)
+
+        self.assertEqual(local_softmax.chunk(self.world_size)[self.rank], sharded_softmax.local_tensor())
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_sharded_softmax_basic(self):
+        self._test_sharded_softmax(0)
+        self._test_sharded_softmax(-2)
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_sharded_softmax_on_sharding_dim(self):
+        self._test_sharded_softmax(1)
+        self._test_sharded_softmax(-1)
+
+if __name__ == "__main__":
+    run_tests()
test/distributed/_shard/test_partial_tensor.py:

@@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import (
 from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
     with_comms,
+    TEST_GPU_NUM
 )
 from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
     _chunk_sharding_specs_list_for_test,

@@ -70,7 +71,7 @@ class TestPartialTensorReshard(ShardedTensorTestBase):
         self.assertEqual(local_shards[0].tensor, local_result_compare)
 
     @with_comms(init_rpc=False)
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
     def test_partial_tensor_reshard(self):
         specs = _chunk_sharding_specs_list_for_test([0], seed=7)

@@ -81,7 +82,7 @@ class TestPartialTensorReshard(ShardedTensorTestBase):
         self._run_partial_tensor_n_reshard(spec, [17, 21], 2, dist.ReduceOp.MAX)
 
     @with_comms(init_rpc=False)
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
     def test_partial_tensor_reshard_errors(self):
         enumerable_sharding_spec = EnumerableShardingSpec(

@@ -119,6 +120,54 @@ class TestPartialTensorReshard(ShardedTensorTestBase):
             spec, [12, 22], 4, dist.ReduceOp.MAX, dtype=torch.cfloat
         )
 
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_transpose(self):
+        partial_tensor = _PartialTensor(torch.rand(5, 10))
+        partial_tensor = partial_tensor.transpose(0, 1)
+        self.assertEqual(partial_tensor.size(), torch.Size((10, 5)))
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_cat(self):
+        t1 = torch.rand(5, 10)
+        t2 = torch.rand(3, 10)
+        t3 = torch.rand(4, 10)
+        partial_tensors = [_PartialTensor(t1), _PartialTensor(t2), _PartialTensor(t3)]
+        partial_concat = torch.cat(partial_tensors)
+        local_concat = torch.cat([t1, t2, t3])
+        self.assertEqual(local_concat.size(), partial_concat.size())
+
+        # Test dim kwarg
+        t1 = torch.rand(5, 10)
+        t2 = torch.rand(5, 12)
+        t3 = torch.rand(5, 11)
+        partial_tensors = [_PartialTensor(t1), _PartialTensor(t2), _PartialTensor(t3)]
+        partial_concat = torch.cat(partial_tensors, dim=1)
+        local_concat = torch.cat([t1, t2, t3], dim=1)
+        self.assertEqual(local_concat.size(), partial_concat.size())
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_cat_errors(self):
+        with self.assertRaisesRegex(
+            RuntimeError, 'All inputs need to be an instance of _PartialTensor'
+        ):
+            torch.cat([_PartialTensor(torch.rand(10)), torch.rand(10)])
+
+        with self.assertRaisesRegex(
+            RuntimeError, 'reduce_ops need to be the same'
+        ):
+            torch.cat([_PartialTensor(torch.rand(10)), _PartialTensor(torch.rand(10), reduce_op=dist.ReduceOp.MAX)])
+
+        with self.assertRaisesRegex(
+            RuntimeError, '"out" kwarg is not supported'
+        ):
+            torch.cat([_PartialTensor(torch.rand(10)), _PartialTensor(torch.rand(10))], out=torch.rand(10))
+
 
 if __name__ == "__main__":
     run_tests()
test/distributed/_shard/test_replicated_tensor.py:

@@ -22,17 +22,17 @@ from torch.testing._internal.distributed._shard.sharded_tensor import (
 from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import (
     gen_binary_op_func
 )
+from torch.testing._internal.distributed._shard.sharded_tensor import TEST_GPU_NUM
 
 
 class TestReplicatedTensor(ShardedTensorTestBase):
 
     @with_comms(init_rpc=False)
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
     def test_replicated_tensor_basics(self):
         local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
         replica_tensor = ReplicatedTensor(local_tensor)
         print(replica_tensor.process_group)
         # validate it's a replicated tensor by checking values on all ranks
         validated = replica_tensor.validate()
         self.assertEqual(validated, True)

@@ -49,7 +49,7 @@ class TestReplicatedTensor(ShardedTensorTestBase):
         replica_tensor.validate()
 
     @with_comms(init_rpc=False)
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
     def test_replicated_tensor_inter_op_replicated_tensor(self):
         local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}")

@@ -69,7 +69,7 @@ class TestReplicatedTensor(ShardedTensorTestBase):
 
 
     @with_comms(init_rpc=False)
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
     def test_replicated_tensor_inter_op_tensor(self):
         local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4

@@ -84,7 +84,7 @@ class TestReplicatedTensor(ShardedTensorTestBase):
         self.assertEqual(new_tensor, local_tensor + local_rand_tensor)
 
     @with_comms(init_rpc=False)
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
     def test_replicated_tensor_inter_op_sharded_tensor(self):
         torch.manual_seed(self.rank)

@@ -132,7 +132,7 @@ class TestReplicatedTensor(ShardedTensorTestBase):
 
 
     @with_comms(init_rpc=False)
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
     def test_replicated_tensor_implicit_broadcasting(self):
         # use same seed

@@ -174,7 +174,7 @@ class TestReplicatedTensor(ShardedTensorTestBase):
 
 
     @with_comms(init_rpc=False)
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
     def test_replicated_tensor_inter_op_sharded_tensor_errors(self):
         local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4

@@ -201,7 +201,7 @@ class TestReplicatedTensor(ShardedTensorTestBase):
         st1 % replica_tensor
 
     @with_comms(init_rpc=False)
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
     @requires_nccl()
     def test_with_ddp(self):
         # Test Replicated params for DDP

@@ -306,3 +306,31 @@ class TestReplicatedTensor(ShardedTensorTestBase):
         buffer.seek(0)
         obj = torch.load(buffer)
         self.assertEqual(expected_state_dict, obj.state_dict())
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_unsqueeze(self):
+        local_tensor = torch.rand(3, 3, device=self.rank)
+        replicated_tensor = ReplicatedTensor(local_tensor)
+
+        unsqueezed_replicated_tensor = replicated_tensor.unsqueeze(0)
+        unsqueezed_local_tensor = local_tensor.unsqueeze(0)
+
+        self.assertIsInstance(unsqueezed_replicated_tensor, ReplicatedTensor)
+        self.assertIsInstance(torch.unsqueeze(replicated_tensor, 0), ReplicatedTensor)
+        self.assertEqual(unsqueezed_local_tensor, unsqueezed_replicated_tensor)
+        self.assertEqual(torch.unsqueeze(replicated_tensor, 0), unsqueezed_replicated_tensor)
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_getitem(self):
+        local_tensor = torch.rand(3, 3, device=self.rank)
+        replicated_tensor = ReplicatedTensor(local_tensor)
+
+        replicated_tensor_view = replicated_tensor[0]
+        local_tensor_view = local_tensor[0]
+
+        self.assertIsInstance(replicated_tensor_view, ReplicatedTensor)
+        self.assertEqual(local_tensor_view, replicated_tensor_view)
test/run_test.py:

@@ -215,6 +215,7 @@ WINDOWS_BLOCKLIST = [
     "distributed/_shard/sharded_tensor/ops/test_linear",
     "distributed/_shard/sharded_tensor/ops/test_math_ops",
     "distributed/_shard/sharded_tensor/ops/test_matrix_ops",
+    "distributed/_shard/sharded_tensor/ops/test_softmax",
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharded_optim/test_sharded_optim",
     "distributed/_shard/test_partial_tensor",

@@ -240,6 +241,7 @@ ROCM_BLOCKLIST = [
     "distributed/_shard/sharded_tensor/ops/test_linear",
     "distributed/_shard/sharded_tensor/ops/test_math_ops",
     "distributed/_shard/sharded_tensor/ops/test_matrix_ops",
+    "distributed/_shard/sharded_tensor/ops/test_softmax",
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharded_optim/test_sharded_optim",
     "distributed/_shard/test_partial_tensor",
torch/distributed/_shard/partial_tensor.py:

@@ -1,3 +1,6 @@
+import functools
+from typing import Callable, Dict
+
 import torch
 import torch.distributed as dist
 import torch.distributed._shard.sharding_spec as shard_spec

@@ -6,7 +9,36 @@ from torch.distributed._shard.sharded_tensor.api import ShardedTensor
 from torch.distributed.nn.functional import (
     reduce_scatter,
 )
+from torch.overrides import handle_torch_function
+
+# Custom sharded ops
+_PARTIAL_TENSOR_OPS: Dict[Callable, Callable] = {}
+
+def _register_partial_tensor_op(op, func):
+    from inspect import signature
+    if len(signature(func).parameters) != 3:
+        raise TypeError(
+            f'Partial tensor op function expects signature: '
+            f'(types, args, kwargs), but received '
+            f'signature: {signature(func)}')
+
+    global _PARTIAL_TENSOR_OPS
+    _PARTIAL_TENSOR_OPS[op] = func
+
+def _custom_partial_tensor_op(func):
+    """
+    Decorator for a custom partial tensor op.
+    Args:
+        func(Callable): Torch function for which we want to provide a PartialTensor
+            implementation (ex: torch.nn.functional.linear)
+    """
+    def decorator_sharded_func(wrapped_func):
+        _register_partial_tensor_op(func, wrapped_func)
+
+        @functools.wraps(wrapped_func)
+        def wrapper(*args, **kwargs):
+            return wrapped_func(*args, **kwargs)
+        return wrapper
+    return decorator_sharded_func
+
 class _PartialTensor(object):
     """

@@ -175,3 +207,61 @@ class _PartialTensor(object):
             sharded_tensor_size,
             process_group=self.process_group,
         )
+
+    def size(self):
+        return self.local_shard.size()
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if func in _PARTIAL_TENSOR_OPS:
+            return _PARTIAL_TENSOR_OPS[func](types, args, kwargs)
+
+        raise RuntimeError(
+            f"torch function '{func.__name__}', with args: {args} and "
+            f"kwargs: {kwargs} not supported for PartialTensor!")
+
+    def transpose(self, dim0, dim1):
+        return handle_torch_function(torch.Tensor.transpose, (self, dim0, dim1), self, dim0, dim1)
+
+def _transpose_impl(types, args=(), kwargs=None):
+    input = args[0]
+    dim0 = args[1]
+    dim1 = args[2]
+    return _PartialTensor(
+        torch.transpose(input.local_shard, dim0, dim1),
+        input.process_group,
+        input.reduce_op
+    )
+
+@_custom_partial_tensor_op(torch.Tensor.transpose)
+def partial_transpose(types, args=(), kwargs=None):
+    return _transpose_impl(types, args, kwargs)
+
+@_custom_partial_tensor_op(torch.transpose)
+def partial_torch_transpose(types, args=(), kwargs=None):
+    return _transpose_impl(types, args, kwargs)
+
+@_custom_partial_tensor_op(torch.cat)
+def partial_cat(types, args=(), kwargs=None):
+    input_list = args[0]
+    if len(input_list) == 0:
+        raise RuntimeError('Empty list of tensors to torch.cat!')
+
+    local_shards = []
+    for idx, input in enumerate(input_list):
+        if not isinstance(input, _PartialTensor):
+            raise RuntimeError('All inputs need to be an instance of _PartialTensor')
+        if idx == 0:
+            reduce_op = input.reduce_op
+        elif reduce_op != input.reduce_op:
+            raise RuntimeError(
+                'All _PartialTensor reduce_ops need to be the same, found: '
+                f'{reduce_op} and {input.reduce_op}')
+
+        local_shards.append(input.local_shard)
+
+    if kwargs is None:
+        dim = 0
+    else:
+        if 'out' in kwargs:
+            raise RuntimeError('"out" kwarg is not supported!')
+        dim = kwargs['dim'] if 'dim' in kwargs else 0
+
+    return _PartialTensor(torch.cat(local_shards, dim), input.process_group, input.reduce_op)
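The decorator above makes the `_PartialTensor` op table extensible. As a hedged illustration (not part of this diff; `partial_squeeze` is a hypothetical op, and the module path is assumed to be torch.distributed._shard.partial_tensor), registering one more shard-local op would look like this:

import torch
from torch.distributed._shard.partial_tensor import (
    _PartialTensor,
    _custom_partial_tensor_op,
)

# Hypothetical example, not in this commit: torch.squeeze applied to the
# local shard, keeping the process group and pending reduce_op unchanged.
# The registered function must take exactly (types, args, kwargs).
@_custom_partial_tensor_op(torch.squeeze)
def partial_squeeze(types, args=(), kwargs=None):
    input = args[0]
    return _PartialTensor(
        torch.squeeze(input.local_shard),
        input.process_group,
        input.reduce_op,
    )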
torch/distributed/_shard/replicated_tensor.py:

@@ -5,6 +5,14 @@ from torch.distributed._shard.sharded_tensor.api import ShardedTensor
 from torch.distributed import distributed_c10d
 from torch.overrides import get_default_nowrap_functions
 
+_REPLICATED_WITH_NON_TENSOR_ALLOWLIST = [
+    # List of ops where if parameters are a combination of ReplicatedTensors
+    # and non-tensors, we can still return a ReplicatedTensor as the result.
+    torch.unsqueeze,
+    torch.Tensor.unsqueeze,
+    torch.Tensor.__getitem__,
+]
+
 class ReplicatedTensor(torch.Tensor):
     """
     ReplicatedTensor represents a tensor which is replicated across the `world_size` and

@@ -60,12 +68,13 @@ class ReplicatedTensor(torch.Tensor):
         # are all replicated tensor operands; we have to do this to ensure we do not
         # convert results back to ReplicatedTensor if not all operands are replicated.
         all_replicated = True
+        replicated_with_non_tensor = True
         replicated_pg = None
 
         def dispatch_arg(arg):
             # This function returns a tuple: the first element represents whether the op
             # has been executed, and the second element represents the result of the execution
-            nonlocal replicated_pg, all_replicated
+            nonlocal replicated_pg, all_replicated, replicated_with_non_tensor
             if isinstance(arg, ShardedTensor):
                 # redispatch to ShardedTensor
                 # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor

@@ -78,6 +87,9 @@ class ReplicatedTensor(torch.Tensor):
                     f"ReplicatedTensor operands must be in the same process group "
                     f"in torch function '{func.__name__}', but found at least two "
                     f"ReplicatedTensor operands in different process groups! ")
+            elif isinstance(arg, torch.Tensor):
+                replicated_with_non_tensor = False
+                all_replicated = False
             else:
                 all_replicated = False
 

@@ -101,7 +113,12 @@ class ReplicatedTensor(torch.Tensor):
         rs = func(*args, **kwargs)
         if func in get_default_nowrap_functions():
             return rs
-        if all_replicated and isinstance(rs, torch.Tensor) and not isinstance(rs, cls):
+
+        result_not_replicated = isinstance(rs, torch.Tensor) and not isinstance(rs, ReplicatedTensor)
+        should_convert_to_replicated = all_replicated or (
+            replicated_with_non_tensor and func in _REPLICATED_WITH_NON_TENSOR_ALLOWLIST
+        )
+        if result_not_replicated and should_convert_to_replicated:
             # if all operands are ReplicatedTensors and the op does not get dispatched to
             # ShardedTensor's __torch_function__, the result is a plain torch.Tensor; convert
             # and return a ReplicatedTensor according to our inter-op rule
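In effect, the allowlist relaxes the previous "all operands must be ReplicatedTensor" rule for ops whose extra arguments are non-tensors (an int dim or index). A hedged sketch of the resulting dispatch behavior (illustrative, not from the commit; assumes an initialized default process group, with `rank` the local CUDA device):

import torch
from torch.distributed._shard.replicated_tensor import ReplicatedTensor

def check_dispatch(rank):
    rt = ReplicatedTensor(torch.ones(3, 3, device=rank))

    # All-ReplicatedTensor operands: result is converted back to ReplicatedTensor.
    assert isinstance(rt + rt, ReplicatedTensor)

    # Allowlisted ops mixing a ReplicatedTensor with non-tensor args now
    # stay replicated too (this is what the PR adds).
    assert isinstance(rt.unsqueeze(0), ReplicatedTensor)
    assert isinstance(rt[0], ReplicatedTensor)

    # A plain-tensor operand may differ per rank, so the result stays a
    # plain torch.Tensor.
    assert not isinstance(rt + torch.rand(3, 3, device=rank), ReplicatedTensor)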
torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py:

@@ -9,5 +9,6 @@ from .init import kaiming_uniform_, normal_, uniform_, constant_
 from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.linear import sharded_linear
 from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import sharded_embedding
 from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import sharded_embedding_bag
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.softmax import sharded_softmax
 import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.math_ops
 import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.matrix_ops
@@ -7,3 +7,4 @@ from ._common import (
 _register_sharded_op_on_local_shards(torch.nn.functional.gelu)
 _register_sharded_op_on_local_shards(torch.nn.functional.relu)
 _register_sharded_op_on_local_shards(torch.nn.functional.dropout)
+_register_sharded_op_on_local_shards(torch.Tensor.tanh)
@@ -91,30 +91,6 @@ _register_sharded_op_on_local_tensor(
 )
 
 
-def sharded_softmax_check(*args, **kwargs):
-    """
-    Perform extra checks for the ``torch.Tensor.softmax`` op; for now we don't
-    support doing softmax on the sharding dim.
-
-    Args: same as ``torch.Tensor.softmax``.
-
-    Return: None
-    """
-    st = args[0]
-    dim = kwargs.get("dim")
-    dim = dim if dim is not None else 1  # If no dim specified, softmax uses 1 as dim.
-    if dim == st.sharding_spec().dim:
-        raise NotImplementedError(
-            "Only support performing softmax on non-sharding dim now."
-        )
-
-
-_register_sharded_op_on_local_tensor(
-    torch.nn.functional.softmax,
-    extra_check=sharded_softmax_check,
-)
-
-
 def sharded_masked_fill_check(*args, **kwargs):
     """
     Perform extra checks for the ``torch.Tensor.masked_fill`` op.
torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/softmax.py (new file, 30 lines):

@@ -0,0 +1,30 @@
+import torch
+from torch.distributed._shard.sharded_tensor import (
+    ShardedTensor,
+)
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
+from ._common import (
+    _register_sharded_op_on_local_tensor,
+)
+
+
+@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.softmax)
+def sharded_softmax(types, args=(), kwargs=None):
+    input = args[0]
+    pg = input._process_group
+    dim = kwargs['dim']
+    sharding_dim = input.sharding_spec().dim
+    ndims = len(input.size())
+    if dim == sharding_dim or dim + ndims == sharding_dim or sharding_dim + ndims == dim:
+        exp = torch.exp(input.local_tensor())
+        exp_sum = exp.sum(dim=dim).unsqueeze(dim=dim)
+        exp_sum = torch.distributed.nn.functional.all_reduce(exp_sum, group=pg)
+        smax = torch.div(exp, exp_sum)
+    else:
+        smax = torch.nn.functional.softmax(input.local_tensor(), dim)
+    return ShardedTensor._init_from_local_tensor(smax, input.sharding_spec(), input.size(), process_group=pg)
+
+
+_register_sharded_op_on_local_tensor(
+    torch.nn.functional.softmax,
+    customized_func=sharded_softmax,
+)
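A note on the sharding-dim branch above: softmax's denominator is a sum of exponentials, so when the softmax dim equals the sharding dim, each rank can compute the exponent sums of its own shard and a single all_reduce completes the global denominator. The snippet below (illustrative, not from the commit) simulates that decomposition on one process, with two chunks standing in for two ranks' local shards:

import torch

x = torch.rand(4, 6)
shards = x.chunk(2, dim=0)  # two "ranks'" local shards along the sharding dim
dim = 0                     # softmax dim == sharding dim

# Each "rank" computes exponent sums over its local shard; summing the
# partial sums reproduces what the all_reduce in sharded_softmax computes.
partial_sums = [s.exp().sum(dim=dim) for s in shards]
global_sum = torch.stack(partial_sums).sum(dim=0)

result = torch.cat([s.exp() / global_sum for s in shards], dim=dim)
assert torch.allclose(result, torch.nn.functional.softmax(x, dim=dim))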