From 9e52b50e347d980d01a6d31a62a2acb04442aed3 Mon Sep 17 00:00:00 2001
From: pritam <9958665+pritamdamania87@users.noreply.github.com>
Date: Fri, 6 May 2022 09:04:08 -0700
Subject: [PATCH] Additional ops for ShardedTensor, ReplicatedTensor and PartialTensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76477

Adding the following ops:
1) softmax for ShardedTensor
2) getitem and unsqueeze for ReplicatedTensor
3) transpose and cat for PartialTensor

Differential Revision: [D35979510](https://our.internmc.facebook.com/intern/diff/D35979510/)

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
---
 .../_shard/sharded_tensor/ops/test_softmax.py |  57 ++++++++++++
 .../distributed/_shard/test_partial_tensor.py |  53 ++++++++++-
 .../_shard/test_replicated_tensor.py          |  44 +++++++--
 test/run_test.py                              |   2 +
 torch/distributed/_shard/partial_tensor.py    |  90 +++++++++++++++++++
 torch/distributed/_shard/replicated_tensor.py |  21 ++++-
 .../_shard/sharded_tensor/_ops/__init__.py    |   1 +
 .../sharded_tensor/_ops/elementwise_ops.py    |   1 +
 .../chunk_sharding_spec_ops/matrix_ops.py     |  24 -----
 .../chunk_sharding_spec_ops/softmax.py        |  30 +++++++
 10 files changed, 287 insertions(+), 36 deletions(-)
 create mode 100644 test/distributed/_shard/sharded_tensor/ops/test_softmax.py
 create mode 100644 torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/softmax.py

diff --git a/test/distributed/_shard/sharded_tensor/ops/test_softmax.py b/test/distributed/_shard/sharded_tensor/ops/test_softmax.py
new file mode 100644
index 00000000000..e2e5c17a1f7
--- /dev/null
+++ b/test/distributed/_shard/sharded_tensor/ops/test_softmax.py
@@ -0,0 +1,57 @@
+# Owner(s): ["oncall: distributed"]
+
+import sys
+import torch
+from torch.testing._internal.common_utils import (
+    TEST_WITH_DEV_DBG_ASAN,
+    run_tests,
+)
+from torch.testing._internal.distributed._shard.sharded_tensor import (
+    TEST_GPU_NUM,
+    ShardedTensorTestBase,
+    with_comms,
+)
+from torch.testing._internal.common_distributed import (
+    requires_nccl,
+    skip_if_lt_x_gpu,
+)
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard import _shard_tensor
+
+if TEST_WITH_DEV_DBG_ASAN:
+    print(
+        "Skip dev-asan as torch + multiprocessing spawn have known issues",
+        file=sys.stderr,
+    )
+    sys.exit(0)
+
+
+class TestShardedSoftmax(ShardedTensorTestBase):
+
+    def _test_sharded_softmax(self, dim):
+        torch.manual_seed(0)
+        local_tensor = torch.rand(10, 10, device=self.rank)
+        local_softmax = torch.nn.functional.softmax(local_tensor, dim)
+
+        spec = ChunkShardingSpec(dim=0, placements=[f'rank:{idx}/cuda:{idx}' for idx in range(self.world_size)])
+        st = _shard_tensor(local_tensor, spec)
+        sharded_softmax = torch.nn.functional.softmax(st, dim)
+
+        self.assertEqual(local_softmax.chunk(self.world_size)[self.rank], sharded_softmax.local_tensor())
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_sharded_softmax_basic(self):
+        self._test_sharded_softmax(1)
+        self._test_sharded_softmax(-1)
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(TEST_GPU_NUM)
+    @requires_nccl()
+    def test_sharded_softmax_on_sharding_dim(self):
+        self._test_sharded_softmax(0)
+        self._test_sharded_softmax(-2)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/distributed/_shard/test_partial_tensor.py b/test/distributed/_shard/test_partial_tensor.py
index 243240e2a19..7c544c8fb7d 100644
--- a/test/distributed/_shard/test_partial_tensor.py
+++
b/test/distributed/_shard/test_partial_tensor.py @@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import ( from torch.testing._internal.distributed._shard.sharded_tensor import ( ShardedTensorTestBase, with_comms, + TEST_GPU_NUM ) from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import ( _chunk_sharding_specs_list_for_test, @@ -70,7 +71,7 @@ class TestPartialTensorReshard(ShardedTensorTestBase): self.assertEqual(local_shards[0].tensor, local_result_compare) @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(4) + @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() def test_partial_tensor_reshard(self): specs = _chunk_sharding_specs_list_for_test([0], seed=7) @@ -81,7 +82,7 @@ class TestPartialTensorReshard(ShardedTensorTestBase): self._run_partial_tensor_n_reshard(spec, [17, 21], 2, dist.ReduceOp.MAX) @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(4) + @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() def test_partial_tensor_reshard_errors(self): enumerable_sharding_spec = EnumerableShardingSpec( @@ -119,6 +120,54 @@ class TestPartialTensorReshard(ShardedTensorTestBase): spec, [12, 22], 4, dist.ReduceOp.MAX, dtype=torch.cfloat ) + @with_comms(init_rpc=False) + @skip_if_lt_x_gpu(TEST_GPU_NUM) + @requires_nccl() + def test_transpose(self): + partial_tensor = _PartialTensor(torch.rand(5, 10)) + partial_tensor = partial_tensor.transpose(0, 1) + self.assertEqual(partial_tensor.size(), torch.Size((10, 5))) + + @with_comms(init_rpc=False) + @skip_if_lt_x_gpu(TEST_GPU_NUM) + @requires_nccl() + def test_cat(self): + t1 = torch.rand(5, 10) + t2 = torch.rand(3, 10) + t3 = torch.rand(4, 10) + partial_tensors = [_PartialTensor(t1), _PartialTensor(t2), _PartialTensor(t3)] + partial_concat = torch.cat(partial_tensors) + local_concat = torch.cat([t1, t2, t3]) + self.assertEqual(local_concat.size(), partial_concat.size()) + + # Test dim kwarg + t1 = torch.rand(5, 10) + t2 = torch.rand(5, 12) + t3 = torch.rand(5, 11) + partial_tensors = [_PartialTensor(t1), _PartialTensor(t2), _PartialTensor(t3)] + partial_concat = torch.cat(partial_tensors, dim=1) + local_concat = torch.cat([t1, t2, t3], dim=1) + self.assertEqual(local_concat.size(), partial_concat.size()) + + @with_comms(init_rpc=False) + @skip_if_lt_x_gpu(TEST_GPU_NUM) + @requires_nccl() + def test_cat_errors(self): + with self.assertRaisesRegex( + RuntimeError, 'All inputs need to be an instance of _PartialTensor' + ): + torch.cat([_PartialTensor(torch.rand(10)), torch.rand(10)]) + + with self.assertRaisesRegex( + RuntimeError, 'reduce_ops need to be the same' + ): + torch.cat([_PartialTensor(torch.rand(10)), _PartialTensor(torch.rand(10), reduce_op=dist.ReduceOp.MAX)]) + + with self.assertRaisesRegex( + RuntimeError, '"out" kwarg is not supported' + ): + torch.cat([_PartialTensor(torch.rand(10)), _PartialTensor(torch.rand(10))], out=torch.rand(10)) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_shard/test_replicated_tensor.py b/test/distributed/_shard/test_replicated_tensor.py index e0cb51e9ba4..36dc6cfabdf 100644 --- a/test/distributed/_shard/test_replicated_tensor.py +++ b/test/distributed/_shard/test_replicated_tensor.py @@ -22,17 +22,17 @@ from torch.testing._internal.distributed._shard.sharded_tensor import ( from torch.testing._internal.distributed._shard.sharded_tensor._test_ops_common import ( gen_binary_op_func ) +from torch.testing._internal.distributed._shard.sharded_tensor import TEST_GPU_NUM class TestReplicatedTensor(ShardedTensorTestBase): @with_comms(init_rpc=False) - 
@skip_if_lt_x_gpu(4) + @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() def test_replicated_tensor_basics(self): local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4 replica_tensor = ReplicatedTensor(local_tensor) - print(replica_tensor.process_group) # validate it's a replicated tensor by checking values on all rank validated = replica_tensor.validate() self.assertEqual(validated, True) @@ -49,7 +49,7 @@ class TestReplicatedTensor(ShardedTensorTestBase): replica_tensor.validate() @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(4) + @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() def test_replicated_tensor_inter_op_replicated_tensor(self): local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") @@ -69,7 +69,7 @@ class TestReplicatedTensor(ShardedTensorTestBase): @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(4) + @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() def test_replicated_tensor_inter_op_tensor(self): local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4 @@ -84,7 +84,7 @@ class TestReplicatedTensor(ShardedTensorTestBase): self.assertEqual(new_tensor, local_tensor + local_rand_tensor) @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(4) + @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() def test_replicated_tensor_inter_op_sharded_tensor(self): torch.manual_seed(self.rank) @@ -132,7 +132,7 @@ class TestReplicatedTensor(ShardedTensorTestBase): @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(4) + @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() def test_replicated_tensor_implicit_broadcasting(self): # use same seed @@ -174,7 +174,7 @@ class TestReplicatedTensor(ShardedTensorTestBase): @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(4) + @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() def test_replicated_tensor_inter_op_sharded_tensor_errors(self): local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4 @@ -201,7 +201,7 @@ class TestReplicatedTensor(ShardedTensorTestBase): st1 % replica_tensor @with_comms(init_rpc=False) - @skip_if_lt_x_gpu(4) + @skip_if_lt_x_gpu(TEST_GPU_NUM) @requires_nccl() def test_with_ddp(self): # Test Replicated params for DDP @@ -306,3 +306,31 @@ class TestReplicatedTensor(ShardedTensorTestBase): buffer.seek(0) obj = torch.load(buffer) self.assertEqual(expected_state_dict, obj.state_dict()) + + @with_comms(init_rpc=False) + @skip_if_lt_x_gpu(TEST_GPU_NUM) + @requires_nccl() + def test_unsqueeze(self): + local_tensor = torch.rand(3, 3, device=self.rank) + replicated_tensor = ReplicatedTensor(local_tensor) + + unsqueezed_replicated_tensor = replicated_tensor.unsqueeze(0) + unsqueezed_local_tensor = local_tensor.unsqueeze(0) + + self.assertIsInstance(unsqueezed_replicated_tensor, ReplicatedTensor) + self.assertIsInstance(torch.unsqueeze(replicated_tensor, 0), ReplicatedTensor) + self.assertEqual(unsqueezed_local_tensor, unsqueezed_replicated_tensor) + self.assertEqual(torch.unsqueeze(replicated_tensor, 0), unsqueezed_replicated_tensor) + + @with_comms(init_rpc=False) + @skip_if_lt_x_gpu(TEST_GPU_NUM) + @requires_nccl() + def test_getitem(self): + local_tensor = torch.rand(3, 3, device=self.rank) + replicated_tensor = ReplicatedTensor(local_tensor) + + replicated_tensor_view = replicated_tensor[0] + local_tensor_view = local_tensor[0] + + self.assertIsInstance(replicated_tensor_view, ReplicatedTensor) + self.assertEqual(local_tensor_view, replicated_tensor_view) diff --git a/test/run_test.py b/test/run_test.py index 2a1b913a10a..57083e8c8f1 100644 --- a/test/run_test.py +++ b/test/run_test.py @@ -215,6 +215,7 @@ WINDOWS_BLOCKLIST = 
[ "distributed/_shard/sharded_tensor/ops/test_linear", "distributed/_shard/sharded_tensor/ops/test_math_ops", "distributed/_shard/sharded_tensor/ops/test_matrix_ops", + "distributed/_shard/sharded_tensor/ops/test_softmax", "distributed/_shard/sharding_spec/test_sharding_spec", "distributed/_shard/sharded_optim/test_sharded_optim", "distributed/_shard/test_partial_tensor", @@ -240,6 +241,7 @@ ROCM_BLOCKLIST = [ "distributed/_shard/sharded_tensor/ops/test_linear", "distributed/_shard/sharded_tensor/ops/test_math_ops", "distributed/_shard/sharded_tensor/ops/test_matrix_ops", + "distributed/_shard/sharded_tensor/ops/test_softmax", "distributed/_shard/sharding_spec/test_sharding_spec", "distributed/_shard/sharded_optim/test_sharded_optim", "distributed/_shard/test_partial_tensor", diff --git a/torch/distributed/_shard/partial_tensor.py b/torch/distributed/_shard/partial_tensor.py index b6cff1cfa3a..6ef8a8146ea 100644 --- a/torch/distributed/_shard/partial_tensor.py +++ b/torch/distributed/_shard/partial_tensor.py @@ -1,3 +1,6 @@ +import functools +from typing import Callable, Dict + import torch import torch.distributed as dist import torch.distributed._shard.sharding_spec as shard_spec @@ -6,7 +9,36 @@ from torch.distributed._shard.sharded_tensor.api import ShardedTensor from torch.distributed.nn.functional import ( reduce_scatter, ) +from torch.overrides import handle_torch_function +# Custom sharded ops +_PARTIAL_TENSOR_OPS: Dict[Callable, Callable] = {} +def _register_partial_tensor_op(op, func): + from inspect import signature + if len(signature(func).parameters) != 3: + raise TypeError( + f'Partial tensor op function expects signature: ' + f'(types, args, kwargs), but received ' + f'signature: {signature(func)}') + + global _PARTIAL_TENSOR_OPS + _PARTIAL_TENSOR_OPS[op] = func + +def _custom_partial_tensor_op(func): + """ + Decorate for custom partial tensor op + Args: + func(Callable): Torch function for which we want to provide a PartialTensor + implementation (ex: torch.nn.functional.linear) + """ + def decorator_sharded_func(wrapped_func): + _register_partial_tensor_op(func, wrapped_func) + + @functools.wraps(wrapped_func) + def wrapper(*args, **kwargs): + return wrapped_func(*args, **kwargs) + return wrapper + return decorator_sharded_func class _PartialTensor(object): """ @@ -175,3 +207,61 @@ class _PartialTensor(object): sharded_tensor_size, process_group=self.process_group, ) + + def size(self): + return self.local_shard.size() + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if func in _PARTIAL_TENSOR_OPS: + return _PARTIAL_TENSOR_OPS[func](types, args, kwargs) + + raise RuntimeError( + f"torch function '{func.__name__}', with args: {args} and " + f"kwargs: {kwargs} not supported for PartialTensor!") + + def transpose(self, dim0, dim1): + return handle_torch_function(torch.Tensor.transpose, (self, dim0, dim1), self, dim0, dim1) + +def _transpose_impl(types, args=(), kwargs=None): + input = args[0] + dim0 = args[1] + dim1 = args[2] + return _PartialTensor( + torch.transpose(input.local_shard, dim0, dim1), + input.process_group, + input.reduce_op + ) + +@_custom_partial_tensor_op(torch.Tensor.transpose) +def partial_transpose(types, args=(), kwargs=None): + return _transpose_impl(types, args, kwargs) + +@_custom_partial_tensor_op(torch.transpose) +def partial_torch_transpose(types, args=(), kwargs=None): + return _transpose_impl(types, args, kwargs) + +@_custom_partial_tensor_op(torch.cat) +def partial_cat(types, args=(), kwargs=None): + 
input_list = args[0]
+    if len(input_list) == 0:
+        raise RuntimeError('Empty list of tensors to torch.cat!')
+
+    local_shards = []
+    for idx, input in enumerate(input_list):
+        if not isinstance(input, _PartialTensor):
+            raise RuntimeError('All inputs need to be an instance of _PartialTensor')
+        if idx == 0:
+            reduce_op = input.reduce_op
+        elif reduce_op != input.reduce_op:
+            raise RuntimeError(f'All _PartialTensor reduce_ops need to be the same, found: {reduce_op} and {input.reduce_op}')
+
+        local_shards.append(input.local_shard)
+
+    if kwargs is None:
+        dim = 0
+    else:
+        if 'out' in kwargs:
+            raise RuntimeError('"out" kwarg is not supported!')
+        dim = kwargs['dim'] if 'dim' in kwargs else 0
+    return _PartialTensor(torch.cat(local_shards, dim), input.process_group, input.reduce_op)
diff --git a/torch/distributed/_shard/replicated_tensor.py b/torch/distributed/_shard/replicated_tensor.py
index 61b329a3d37..9e44d98c297 100644
--- a/torch/distributed/_shard/replicated_tensor.py
+++ b/torch/distributed/_shard/replicated_tensor.py
@@ -5,6 +5,14 @@ from torch.distributed._shard.sharded_tensor.api import ShardedTensor
 from torch.distributed import distributed_c10d
 from torch.overrides import get_default_nowrap_functions
 
+_REPLICATED_WITH_NON_TENSOR_ALLOWLIST = [
+    # List of ops where if parameters are a combination of ReplicatedTensors
+    # and non-tensors, we can still return a ReplicatedTensor as the result.
+    torch.unsqueeze,
+    torch.Tensor.unsqueeze,
+    torch.Tensor.__getitem__,
+]
+
 class ReplicatedTensor(torch.Tensor):
     """
     ReplicatedTensor represents a tensor which is replicated across the `world_size` and
@@ -60,12 +68,13 @@ class ReplicatedTensor(torch.Tensor):
         # are all replicated tensor operands, we have to do this to ensure we do not
         # converting results back to ReplicatedTensor if not all operands are replicated.
         all_replicated = True
+        replicated_with_non_tensor = True
         replicated_pg = None

         def dispatch_arg(arg):
             # This function returns a tuple, first element represents whether the op been
             # executed, the second element represents the result of the execution
-            nonlocal replicated_pg, all_replicated
+            nonlocal replicated_pg, all_replicated, replicated_with_non_tensor
             if isinstance(arg, ShardedTensor):
                 # redispatch to ShardedTensor
                 # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
@@ -78,6 +87,9 @@ class ReplicatedTensor(torch.Tensor):
                         f"ReplicatedTensor operands must be in the same process group "
                         f"in torch function '{func.__name__}', but found at least two "
                         f"ReplicatedTensor operands in different process groups!
") + elif isinstance(arg, torch.Tensor): + replicated_with_non_tensor = False + all_replicated = False else: all_replicated = False @@ -101,7 +113,12 @@ class ReplicatedTensor(torch.Tensor): rs = func(*args, **kwargs) if func in get_default_nowrap_functions(): return rs - if all_replicated and isinstance(rs, torch.Tensor) and not isinstance(rs, cls): + + result_not_replicated = isinstance(rs, torch.Tensor) and not isinstance(rs, ReplicatedTensor) + should_convert_to_replicated = all_replicated or ( + replicated_with_non_tensor and func in _REPLICATED_WITH_NON_TENSOR_ALLOWLIST + ) + if result_not_replicated and should_convert_to_replicated: # if all operands are ReplicatedTensors and does not get dispatched to ShardedTensor # __torch_function__, result is a torch.Tensor, then we convert and return a # ReplicatedTensor according to our inter-op rule diff --git a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py index f59c8d30036..081247a81fd 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py @@ -9,5 +9,6 @@ from .init import kaiming_uniform_, normal_, uniform_, constant_ from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.linear import sharded_linear from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import sharded_embedding from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import sharded_embedding_bag +from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.softmax import sharded_softmax import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.math_ops import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.matrix_ops diff --git a/torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py index dc65277b734..eb6c5d54e38 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py @@ -7,3 +7,4 @@ from ._common import ( _register_sharded_op_on_local_shards(torch.nn.functional.gelu) _register_sharded_op_on_local_shards(torch.nn.functional.relu) _register_sharded_op_on_local_shards(torch.nn.functional.dropout) +_register_sharded_op_on_local_shards(torch.Tensor.tanh) diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/matrix_ops.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/matrix_ops.py index 5bf3fb28c11..26dacac9df1 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/matrix_ops.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/matrix_ops.py @@ -91,30 +91,6 @@ _register_sharded_op_on_local_tensor( ) -def sharded_softmax_check(*args, **kwargs): - """ - Perform extra checks for ``torch.Tensor.softmax`` op for now we don't support - doing softmax on the sharding dim. - - Args: same as ``torch.Tensor.softmax``. - - Return: None - """ - st = args[0] - dim = kwargs.get("dim") - dim = dim if dim is not None else 1 # If no dim specified, softmax use 1 as dim. - if dim == st.sharding_spec().dim: - raise NotImplementedError( - "Only support performing softmax on non-sharding dim now." 
-        )
-
-
-_register_sharded_op_on_local_tensor(
-    torch.nn.functional.softmax,
-    extra_check=sharded_softmax_check,
-)
-
-
 def sharded_masked_fill_check(*args, **kwargs):
     """
     Perform extra checks for the ``torch.Tensor.masked_fill`` op.
diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/softmax.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/softmax.py
new file mode 100644
index 00000000000..e670ba04d0d
--- /dev/null
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/softmax.py
@@ -0,0 +1,30 @@
+import torch
+from torch.distributed._shard.sharded_tensor import (
+    ShardedTensor,
+)
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
+from ._common import (
+    _register_sharded_op_on_local_tensor,
+)
+
+@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.softmax)
+def sharded_softmax(types, args=(), kwargs=None):
+    input = args[0]
+    pg = input._process_group
+    dim = kwargs['dim']
+    sharding_dim = input.sharding_spec().dim
+    ndims = len(input.size())
+    if dim == sharding_dim or dim + ndims == sharding_dim or sharding_dim + ndims == dim:
+        exp = torch.exp(input.local_tensor())
+        exp_sum = exp.sum(dim=dim).unsqueeze(dim=dim)
+        exp_sum = torch.distributed.nn.functional.all_reduce(exp_sum, group=pg)
+        smax = torch.div(exp, exp_sum)
+    else:
+        smax = torch.nn.functional.softmax(input.local_tensor(), dim=dim)
+    return ShardedTensor._init_from_local_tensor(smax, input.sharding_spec(), input.size(), process_group=pg)
+
+_register_sharded_op_on_local_tensor(
+    torch.nn.functional.softmax,
+    customized_func=sharded_softmax,
+)
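As a rough usage sketch (illustrative only, not part of the patch itself): the snippet below mirrors the new tests, and assumes a multi-GPU NCCL process group has already been initialized on every rank (e.g. via dist.init_process_group("nccl", ...)), with rank i using cuda:i; the variable names (spec, st, rt, pt, ...) are hypothetical.

import torch
import torch.distributed as dist
from torch.distributed._shard import _shard_tensor
from torch.distributed._shard.partial_tensor import _PartialTensor
from torch.distributed._shard.replicated_tensor import ReplicatedTensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

# Assumes dist.init_process_group("nccl", ...) has already run on every rank
# and that rank i has cuda:i as its device.
rank, world_size = dist.get_rank(), dist.get_world_size()

# 1) softmax on a ShardedTensor, including along the sharding dim (dim=0 here).
spec = ChunkShardingSpec(dim=0, placements=[f"rank:{i}/cuda:{i}" for i in range(world_size)])
st = _shard_tensor(torch.rand(10, 10, device=rank), spec)
st_softmax = torch.nn.functional.softmax(st, dim=0)      # still a ShardedTensor
local_piece = st_softmax.local_tensor()                  # this rank's shard of the result

# 2) unsqueeze / getitem on a ReplicatedTensor keep returning ReplicatedTensor.
rt = ReplicatedTensor(torch.ones(3, 3, device=rank))
rt_unsqueezed = rt.unsqueeze(0)                          # shape (1, 3, 3)
rt_row = rt[0]                                           # first row, still replicated

# 3) transpose / cat on _PartialTensor act on the local shard, with no communication.
pt = _PartialTensor(torch.rand(5, 10, device=rank)).transpose(0, 1)   # size (10, 5)
pt_cat = torch.cat([_PartialTensor(torch.rand(5, 10, device=rank)),
                    _PartialTensor(torch.rand(3, 10, device=rank))])  # size (8, 10)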