From 37eb31599cac3995c3938544463e655e35ed33b3 Mon Sep 17 00:00:00 2001
From: pritam <9958665+pritamdamania87@users.noreply.github.com>
Date: Sat, 21 May 2022 22:33:58 +0000
Subject: [PATCH] [reland] Add sharding tests to multigpu-test.sh and fix
 custom operator decorator (#77987)

1. Enabled multigpu tests.
2. Fixed failing multigpu tests.
3. Fixed custom operator decorator to be first preference in operator dispatch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77987
Approved by: https://github.com/fduwjj, https://github.com/wanchaol, https://github.com/janeyx99
---
 .jenkins/pytorch/multigpu-test.sh             | 20 ++++++++++++++
 .../sharded_tensor/ops/test_matrix_ops.py     |  4 +--
 .../sharded_tensor/test_sharded_tensor.py     | 12 ++++-----
 test/run_test.py                              |  8 +++---
 .../_shard/sharded_tensor/__init__.py         | 15 +++++++++--
 .../_shard/sharded_tensor/_ops/_common.py     |  6 ++---
 .../_shard/sharded_tensor/_ops/binary_cmp.py  |  6 ++---
 .../_shard/sharded_tensor/_ops/chunk.py       |  4 +--
 .../_shard/sharded_tensor/_ops/init.py        | 12 ++++-----
 .../_shard/sharded_tensor/_ops/math_ops.py    |  4 +--
 .../_shard/sharded_tensor/_ops/tensor_ops.py  | 26 +++++++++----------
 .../distributed/_shard/sharded_tensor/api.py  | 13 +++++++---
 12 files changed, 84 insertions(+), 46 deletions(-)

diff --git a/.jenkins/pytorch/multigpu-test.sh b/.jenkins/pytorch/multigpu-test.sh
index 481619a8dc3..0d673397a23 100755
--- a/.jenkins/pytorch/multigpu-test.sh
+++ b/.jenkins/pytorch/multigpu-test.sh
@@ -28,4 +28,24 @@ time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
 time python test/run_test.py --verbose -i distributed/test_store
 time python test/run_test.py --verbose -i distributed/test_pg_wrapper
 time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_agent
+time python test/run_test.py --verbose -i distributed/_shard/checkpoint/test_checkpoint
+time python test/run_test.py --verbose -i distributed/_shard/checkpoint/test_file_system_checkpoint
+time python test/run_test.py --verbose -i distributed/_shard/sharding_spec/test_sharding_spec
+time python test/run_test.py --verbose -i distributed/_shard/sharding_plan/test_sharding_plan
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_megatron_prototype
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_sharded_tensor
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_sharded_tensor_reshard
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_chunk
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_elementwise_ops
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_embedding
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_embedding_bag
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_binary_cmp
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_init
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_linear
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_math_ops
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_matrix_ops
+time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_softmax
+time python test/run_test.py --verbose -i distributed/_shard/sharded_optim/test_sharded_optim
+time python test/run_test.py --verbose -i distributed/_shard/test_partial_tensor
+time python test/run_test.py --verbose -i distributed/_shard/test_replicated_tensor
 assert_git_not_dirty
diff --git a/test/distributed/_shard/sharded_tensor/ops/test_matrix_ops.py b/test/distributed/_shard/sharded_tensor/ops/test_matrix_ops.py
index dd074f324df..32b152a8ff3 100644
--- a/test/distributed/_shard/sharded_tensor/ops/test_matrix_ops.py
+++ b/test/distributed/_shard/sharded_tensor/ops/test_matrix_ops.py
@@ -101,8 +101,8 @@ class TestShardedTensorMatrixOps(ShardedTensorTestBase):
             enumerable_spec, 10, 10, init_rrefs=False, dtype=torch.double
         )
         with self.assertRaisesRegex(
-            NotImplementedError,
-            "Only ChunkShardingSpec supported for 'transpose'",
+            RuntimeError,
+            "not supported",
         ):
             st.transpose(1, 0)

diff --git a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
index ae00f47cecf..24318be1edd 100644
--- a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
+++ b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
@@ -19,7 +19,7 @@ from torch.distributed._shard.api import (
     _reshard_output,
 )
 from torch.distributed._shard.sharded_tensor import (
-    sharded_op_impl,
+    custom_sharded_op_impl,
     pre_load_state_dict_hook,
     state_dict_hook,
     ShardedTensor,
@@ -174,7 +174,7 @@ class TestShardParameter(ShardedTensorTestBase):
         with self.assertRaisesRegex(ValueError, 'does not match with src_rank'):
             shard_parameter(fc, 'weight', spec, src_rank=self.rank)

-        with self.assertRaisesRegex(AttributeError, 'Linear have no attribute'):
+        with self.assertRaisesRegex(AttributeError, 'has no attribute'):
             shard_parameter(fc, 'foo', spec)

         with self.assertRaisesRegex(ValueError, 'Expected Linear.bias to be a Tensor, but found str'):
@@ -2463,7 +2463,7 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase):

     @requires_nccl()
     def test_custom_op(self):
-        @sharded_op_impl(torch.asin)
+        @custom_sharded_op_impl(torch.asin)
         def my_sharded_asin(types, args, kwargs, process_group):
             return torch.asin(args[0].local_shards()[0].tensor)

@@ -2491,7 +2491,7 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase):
         from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op

         @custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.linear)
-        def my_sharded_linear(types, args, kwargs):
+        def my_sharded_linear(types, args, kwargs, process_group):
             return t

         spec = ChunkShardingSpec(
@@ -2515,12 +2515,12 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase):

     def test_custom_op_errors(self):
         with self.assertRaisesRegex(TypeError, 'expects signature'):
-            @sharded_op_impl(torch.nn.functional.linear)
+            @custom_sharded_op_impl(torch.nn.functional.linear)
             def my_op1(types, args, kwargs, process_group, random_param):
                 pass

         with self.assertRaisesRegex(TypeError, 'expects signature'):
-            @sharded_op_impl(torch.nn.functional.linear)
+            @custom_sharded_op_impl(torch.nn.functional.linear)
             def my_op2(types):
                 pass
diff --git a/test/run_test.py b/test/run_test.py
index c0ad0a55a02..bc32163fba5 100644
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -201,6 +201,8 @@ WINDOWS_BLOCKLIST = [
     "distributed/pipeline/sync/test_worker",
     "distributed/elastic/agent/server/test/api_test",
     "distributed/elastic/multiprocessing/api_test",
+    "distributed/_shard/checkpoint/test_checkpoint",
+    "distributed/_shard/checkpoint/test_file_system_checkpoint",
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharding_plan/test_sharding_plan",
"distributed/_shard/sharded_tensor/test_megatron_prototype", @@ -216,8 +218,6 @@ WINDOWS_BLOCKLIST = [ "distributed/_shard/sharded_tensor/ops/test_math_ops", "distributed/_shard/sharded_tensor/ops/test_matrix_ops", "distributed/_shard/sharded_tensor/ops/test_softmax", - "distributed/_shard/sharded_tensor/ops/test_tensor_ops", - "distributed/_shard/sharding_spec/test_sharding_spec", "distributed/_shard/sharded_optim/test_sharded_optim", "distributed/_shard/test_partial_tensor", "distributed/_shard/test_replicated_tensor", @@ -228,6 +228,8 @@ ROCM_BLOCKLIST = [ "distributed/rpc/test_faulty_agent", "distributed/rpc/test_tensorpipe_agent", "distributed/rpc/cuda/test_tensorpipe_agent", + "distributed/_shard/checkpoint/test_checkpoint" + "distributed/_shard/checkpoint/test_file_system_checkpoint" "distributed/_shard/sharding_spec/test_sharding_spec", "distributed/_shard/sharding_plan/test_sharding_plan", "distributed/_shard/sharded_tensor/test_megatron_prototype", @@ -243,8 +245,6 @@ ROCM_BLOCKLIST = [ "distributed/_shard/sharded_tensor/ops/test_math_ops", "distributed/_shard/sharded_tensor/ops/test_matrix_ops", "distributed/_shard/sharded_tensor/ops/test_softmax", - "distributed/_shard/sharded_tensor/ops/test_tensor_ops", - "distributed/_shard/sharding_spec/test_sharding_spec", "distributed/_shard/sharded_optim/test_sharded_optim", "distributed/_shard/test_partial_tensor", "distributed/_shard/test_replicated_tensor", diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py index 47a063fce27..2457aa2a1d5 100644 --- a/torch/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -9,6 +9,7 @@ import torch.distributed._shard.sharding_spec as shard_spec from torch.distributed._shard.partial_tensor import _PartialTensor from .api import ( + _CUSTOM_SHARDED_OPS, _SHARDED_OPS, Shard, ShardedTensor, @@ -411,7 +412,7 @@ def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict, if isinstance(state_dict[key], ShardedTensor): setattr(submodule, attr_name, state_dict[key]) -def sharded_op_impl(func): +def custom_sharded_op_impl(func): """ Provides a way for users to write their own custom sharded operator. This can be used to override existing ShardedTensor operators or write a new @@ -420,7 +421,7 @@ def sharded_op_impl(func): parameters, the function provided will be invoked for that operator. Example:: - >>> @sharded_op_impl(torch.nn.functional.linear) + >>> @custom_sharded_op_impl(torch.nn.functional.linear) >>> def my_custom_sharded_linear(types, args, kwargs, process_group): >>> .... >>> @@ -441,6 +442,16 @@ def sharded_op_impl(func): func(Callable): Torch function for which we want to provide a sharded implementation (ex: torch.nn.functional.linear) """ + return functools.partial( + _decorator_func, + op=func, + op_table=_CUSTOM_SHARDED_OPS + ) + +def _sharded_op_impl(func): + """ + Decorator to register a default sharded op. 
+ """ return functools.partial( _decorator_func, op=func, diff --git a/torch/distributed/_shard/sharded_tensor/_ops/_common.py b/torch/distributed/_shard/sharded_tensor/_ops/_common.py index 1191337e381..3366435f83c 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/_common.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/_common.py @@ -1,6 +1,6 @@ import functools from torch.distributed._shard.sharded_tensor import ( - sharded_op_impl, + _sharded_op_impl, Shard, ShardedTensor, ) @@ -13,7 +13,7 @@ def _sharded_op_common(op, early_stop_func, extra_check): Example:: >>> op = torch.transpose - >>> @sharded_op_impl(op) + >>> @_sharded_op_impl(op) >>> @_sharded_op_common(op, early_stop_func, extra_check) >>> def sharded_tensor_op(types, args, kwargs, process_group): >>> .... @@ -82,7 +82,7 @@ def _register_sharded_op_on_local_shards( func (Callable): registered implementation for sharded op for ``__torch_function__`` dispatch. """ - @sharded_op_impl(op) + @_sharded_op_impl(op) @_sharded_op_common(op, early_stop_func, extra_check) def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None): st = args[0] diff --git a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py index cdee16d1890..fa1eded53b5 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py @@ -3,7 +3,7 @@ import torch.distributed as dist import torch.distributed.distributed_c10d as distributed_c10d from torch.distributed._shard.sharded_tensor import ( ShardedTensor, - sharded_op_impl + _sharded_op_impl ) def _communicate_result(result, pg): @@ -59,10 +59,10 @@ def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None): return _communicate_result(True, st1._process_group) -@sharded_op_impl(torch.equal) +@_sharded_op_impl(torch.equal) def equal(types, args, kwargs, process_group): return binary_cmp(torch.equal, types, args, kwargs, process_group) -@sharded_op_impl(torch.allclose) +@_sharded_op_impl(torch.allclose) def allclose(types, args, kwargs, process_group): return binary_cmp(torch.allclose, types, args, kwargs, process_group) diff --git a/torch/distributed/_shard/sharded_tensor/_ops/chunk.py b/torch/distributed/_shard/sharded_tensor/_ops/chunk.py index d167d378cb9..13548aefd8f 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/chunk.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/chunk.py @@ -1,13 +1,13 @@ import torch from torch.distributed._shard.sharded_tensor import ( - sharded_op_impl, + _sharded_op_impl, ShardedTensor, ) from torch.distributed._shard.sharding_spec import ChunkShardingSpec def register_chunk_op(op): - @sharded_op_impl(op) + @_sharded_op_impl(op) def sharded_chunk(types, args=(), kwargs=None, pg=None): """ Handles ``__torch_function__`` dispatch for the chunk op. 
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/init.py b/torch/distributed/_shard/sharded_tensor/_ops/init.py
index eebd02c8296..df5735b6107 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/init.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/init.py
@@ -1,14 +1,14 @@
 import torch
 import torch.distributed._shard.sharded_tensor as sharded_tensor
 from torch.distributed._shard.sharded_tensor import (
-    sharded_op_impl,
+    _sharded_op_impl,
 )

 def validate_param(param, param_name):
     if param is None:
         raise ValueError(f"param: {param_name} shouldn't be None!")

-@sharded_op_impl(torch.nn.init.uniform_)
+@_sharded_op_impl(torch.nn.init.uniform_)
 def uniform_(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
@@ -30,7 +30,7 @@ def uniform_(types, args=(), kwargs=None, pg=None):
         torch.nn.init.uniform_(shard.tensor, a=a, b=b)
     return sharded_tensor

-@sharded_op_impl(torch.nn.init.normal_)
+@_sharded_op_impl(torch.nn.init.normal_)
 def normal_(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the Tensors in sharded_tensor.local_shards with values drawn from the normal
@@ -52,7 +52,7 @@ def normal_(types, args=(), kwargs=None, pg=None):
         torch.nn.init.normal_(shard.tensor, mean=mean, std=std)
     return sharded_tensor

-@sharded_op_impl(torch.nn.init.kaiming_uniform_)
+@_sharded_op_impl(torch.nn.init.kaiming_uniform_)
 def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the Tensors in sharded_tensor.local_shards with values according to the method
@@ -88,7 +88,7 @@ def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
         torch.nn.init.kaiming_uniform_(shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity)
     return sharded_tensor

-@sharded_op_impl(torch.nn.init.constant_)
+@_sharded_op_impl(torch.nn.init.constant_)
 def constant_(types, args=(), kwargs=None, pg=None):
     r"""
     Fills the input ShardedTensor with the value \text{val}val.
@@ -116,7 +116,7 @@ tensor_like_creation_op_map = {

 # tensor ops that behave the same as the default tensor
 def register_tensor_creation_op(op):
-    @sharded_op_impl(op)
+    @_sharded_op_impl(op)
     def tensor_creation_op(types, args=(), kwargs=None, pg=None):
         """
         Handles ``__torch_function__`` dispatch for tensor creation ops that
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py
index f4840864c4c..fa2d30e7e36 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py
@@ -2,7 +2,7 @@ import torch
 from torch import Tensor
 from torch.distributed._shard.sharded_tensor import (
     ShardedTensor,
-    sharded_op_impl
+    _sharded_op_impl
 )
 from torch.distributed._shard.replicated_tensor import ReplicatedTensor
 from torch.distributed._shard._utils import narrow_tensor
@@ -74,7 +74,7 @@ def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None):
             f"kwargs: {kwargs} not supported yet for ShardedTensor!")

 def register_math_op(op):
-    @sharded_op_impl(op)
+    @_sharded_op_impl(op)
     def binary_math_op(types, args=(), kwargs=None, pg=None):
         return binary_math_op_impl(op, types, args, kwargs, pg)
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
index ae75beca501..84d893d6519 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py
@@ -1,7 +1,7 @@
 import copy
 import torch
 from torch.distributed._shard.sharded_tensor import (
-    sharded_op_impl,
+    _sharded_op_impl,
     Shard,
     ShardedTensor,
 )
@@ -10,7 +10,7 @@ from ._common import (
 )
 from torch.distributed._shard.common_op_utils import _register_default_op

-@sharded_op_impl(torch.Tensor.__deepcopy__)
+@_sharded_op_impl(torch.Tensor.__deepcopy__)
 def tensor_deepcopy(types, args=(), kwargs=None, pg=None):
     # NOTE: we directly implement deepcopy magic method
     # instead of using the default tensor.__deepcopy__
@@ -31,18 +31,18 @@ def tensor_deepcopy(types, args=(), kwargs=None, pg=None):

 # Tensor properties access
-_register_default_op(torch.Tensor.requires_grad.__get__, sharded_op_impl)  # type: ignore[attr-defined]
-_register_default_op(torch.Tensor.shape.__get__, sharded_op_impl)  # type: ignore[attr-defined]
-_register_default_op(torch.Tensor.dtype.__get__, sharded_op_impl)  # type: ignore[attr-defined]
-_register_default_op(torch.Tensor.layout.__get__, sharded_op_impl)  # type: ignore[attr-defined]
-_register_default_op(torch.Tensor.size, sharded_op_impl)
-_register_default_op(torch.Tensor.dim, sharded_op_impl)
-_register_default_op(torch.Tensor.ndim.__get__, sharded_op_impl)  # type: ignore[attr-defined]
-_register_default_op(torch.Tensor.is_contiguous, sharded_op_impl)
-_register_default_op(torch.Tensor.contiguous, sharded_op_impl)
+_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.shape.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.dtype.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.layout.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.size, _sharded_op_impl)
+_register_default_op(torch.Tensor.dim, _sharded_op_impl)
+_register_default_op(torch.Tensor.ndim.__get__, _sharded_op_impl)  # type: ignore[attr-defined]
+_register_default_op(torch.Tensor.is_contiguous, _sharded_op_impl)
+_register_default_op(torch.Tensor.contiguous, _sharded_op_impl)

 # __reduce_ex__ to dispatch to get_state/set_state
-_register_default_op(torch.Tensor.__reduce_ex__, sharded_op_impl)
+_register_default_op(torch.Tensor.__reduce_ex__, _sharded_op_impl)

 def sharded_type_as_check(*args, **kwargs):
     """
@@ -153,7 +153,7 @@ _register_sharded_op_on_local_shards(
     customized_func=sharded_detach,
 )

-@sharded_op_impl(torch.Tensor.requires_grad_)
+@_sharded_op_impl(torch.Tensor.requires_grad_)
 def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
     self_st = args[0]
     requires_grad = args[1]
diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py
index 55bb84b0b95..a83f87c8886 100644
--- a/torch/distributed/_shard/sharded_tensor/api.py
+++ b/torch/distributed/_shard/sharded_tensor/api.py
@@ -11,8 +11,8 @@ from typing import (
     cast,
 )
 import copy
+from functools import reduce
 import weakref
-import math
 import threading

 import torch
@@ -49,9 +49,12 @@ _sharded_tensor_lock = threading.Lock()
 _sharded_tensor_current_id = 0
 _sharded_tensor_map: Dict[int, 'weakref.ReferenceType[ShardedTensor]'] = {}

-# Custom sharded ops
+# Default sharded ops
 _SHARDED_OPS: Dict[Callable, Callable] = {}

+# Customized user ops
+_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {}
+
 def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int):
     with _sharded_tensor_lock:
         if sharded_tensor_id not in _sharded_tensor_map:
@@ -284,7 +287,7 @@ class ShardedTensor(object):
                 Default: ``None``
         """
         def shard_size(shard_md):
-            return math.prod(shard_md.shard_sizes)  # type: ignore[attr-defined]
+            return reduce((lambda x, y: x * y), shard_md.shard_sizes)  # type: ignore[attr-defined]

         rank = dist.get_rank(self._process_group)
         full_size = self.metadata().size
@@ -782,6 +785,10 @@ class ShardedTensor(object):
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         def dispatch(st: ShardedTensor, func: Callable):
+            # Dispatch to custom user provided op first if it exists.
+            if func in _CUSTOM_SHARDED_OPS:
+                return _CUSTOM_SHARDED_OPS[func](types, args, kwargs, st._process_group)
+
             # Dispatch to custom sharding spec op if it has one.
             if _has_custom_op(st._sharding_spec, func):
                 return _dispatch_custom_op(
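
The `__torch_function__` change above is what makes the rename matter: `_CUSTOM_SHARDED_OPS` is consulted before sharding-spec custom ops and before the default `_SHARDED_OPS` table, so a user registration always takes precedence. Below is a small usage sketch modeled on the updated `test_custom_op` above; the single-shard body mirrors the test and is illustrative, not a recommended implementation:

import torch
from torch.distributed._shard.sharded_tensor import custom_sharded_op_impl

# Register a user override for torch.asin; because _CUSTOM_SHARDED_OPS is now
# checked first in ShardedTensor.__torch_function__, this shadows any in-tree
# handler for the same op.
@custom_sharded_op_impl(torch.asin)
def my_sharded_asin(types, args, kwargs, process_group):
    # Illustrative body (same as the test): operate on the first local shard.
    return torch.asin(args[0].local_shards()[0].tensor)

# Afterwards, calling torch.asin(st) on a ShardedTensor `st` routes to
# my_sharded_asin(types, args, kwargs, process_group) via __torch_function__.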