[reland] Add sharding tests to multigpu-test.sh and fix custom operator decorator (#77987)

1. Enabled multigpu tests.
2. Fixed failing multigpu tests.
3. Fixed custom operator decorator to be first preference in operator dispatch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77987
Approved by: https://github.com/fduwjj, https://github.com/wanchaol, https://github.com/janeyx99
This commit is contained in:
pritam 2022-05-21 22:33:58 +00:00 committed by PyTorch MergeBot
parent 416899d1a9
commit 37eb31599c
12 changed files with 84 additions and 46 deletions

View file

@ -28,4 +28,24 @@ time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
time python test/run_test.py --verbose -i distributed/test_store
time python test/run_test.py --verbose -i distributed/test_pg_wrapper
time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_agent
time python test/run_test.py --verbose -i distributed/_shard/checkpoint/test_checkpoint
time python test/run_test.py --verbose -i distributed/_shard/checkpoint/test_file_system_checkpoint
time python test/run_test.py --verbose -i distributed/_shard/sharding_spec/test_sharding_spec
time python test/run_test.py --verbose -i distributed/_shard/sharding_plan/test_sharding_plan
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_megatron_prototype
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_sharded_tensor
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_sharded_tensor_reshard
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_chunk
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_elementwise_ops
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_embedding
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_embedding_bag
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_binary_cmp
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_init
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_linear
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_math_ops
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_matrix_ops
time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/ops/test_softmax
time python test/run_test.py --verbose -i distributed/_shard/sharded_optim/test_sharded_optim
time python test/run_test.py --verbose -i distributed/_shard/test_partial_tensor
time python test/run_test.py --verbose -i distributed/_shard/test_replicated_tensor
assert_git_not_dirty

View file

@ -101,8 +101,8 @@ class TestShardedTensorMatrixOps(ShardedTensorTestBase):
enumerable_spec, 10, 10, init_rrefs=False, dtype=torch.double
)
with self.assertRaisesRegex(
NotImplementedError,
"Only ChunkShardingSpec supported for 'transpose'",
RuntimeError,
"not supported",
):
st.transpose(1, 0)

View file

@ -19,7 +19,7 @@ from torch.distributed._shard.api import (
_reshard_output,
)
from torch.distributed._shard.sharded_tensor import (
sharded_op_impl,
custom_sharded_op_impl,
pre_load_state_dict_hook,
state_dict_hook,
ShardedTensor,
@ -174,7 +174,7 @@ class TestShardParameter(ShardedTensorTestBase):
with self.assertRaisesRegex(ValueError, 'does not match with src_rank'):
shard_parameter(fc, 'weight', spec, src_rank=self.rank)
with self.assertRaisesRegex(AttributeError, 'Linear have no attribute'):
with self.assertRaisesRegex(AttributeError, 'has no attribute'):
shard_parameter(fc, 'foo', spec)
with self.assertRaisesRegex(ValueError, 'Expected Linear.bias to be a Tensor, but found str'):
@ -2463,7 +2463,7 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase):
@requires_nccl()
def test_custom_op(self):
@sharded_op_impl(torch.asin)
@custom_sharded_op_impl(torch.asin)
def my_sharded_asin(types, args, kwargs, process_group):
return torch.asin(args[0].local_shards()[0].tensor)
@ -2491,7 +2491,7 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase):
from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.linear)
def my_sharded_linear(types, args, kwargs):
def my_sharded_linear(types, args, kwargs, process_group):
return t
spec = ChunkShardingSpec(
@ -2515,12 +2515,12 @@ class TestShardedTensorCustomOps(ShardedTensorTestBase):
def test_custom_op_errors(self):
with self.assertRaisesRegex(TypeError, 'expects signature'):
@sharded_op_impl(torch.nn.functional.linear)
@custom_sharded_op_impl(torch.nn.functional.linear)
def my_op1(types, args, kwargs, process_group, random_param):
pass
with self.assertRaisesRegex(TypeError, 'expects signature'):
@sharded_op_impl(torch.nn.functional.linear)
@custom_sharded_op_impl(torch.nn.functional.linear)
def my_op2(types):
pass

View file

@ -201,6 +201,8 @@ WINDOWS_BLOCKLIST = [
"distributed/pipeline/sync/test_worker",
"distributed/elastic/agent/server/test/api_test",
"distributed/elastic/multiprocessing/api_test",
"distributed/_shard/checkpoint/test_checkpoint"
"distributed/_shard/checkpoint/test_file_system_checkpoint"
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharding_plan/test_sharding_plan",
"distributed/_shard/sharded_tensor/test_megatron_prototype",
@ -216,8 +218,6 @@ WINDOWS_BLOCKLIST = [
"distributed/_shard/sharded_tensor/ops/test_math_ops",
"distributed/_shard/sharded_tensor/ops/test_matrix_ops",
"distributed/_shard/sharded_tensor/ops/test_softmax",
"distributed/_shard/sharded_tensor/ops/test_tensor_ops",
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_optim/test_sharded_optim",
"distributed/_shard/test_partial_tensor",
"distributed/_shard/test_replicated_tensor",
@ -228,6 +228,8 @@ ROCM_BLOCKLIST = [
"distributed/rpc/test_faulty_agent",
"distributed/rpc/test_tensorpipe_agent",
"distributed/rpc/cuda/test_tensorpipe_agent",
"distributed/_shard/checkpoint/test_checkpoint"
"distributed/_shard/checkpoint/test_file_system_checkpoint"
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharding_plan/test_sharding_plan",
"distributed/_shard/sharded_tensor/test_megatron_prototype",
@ -243,8 +245,6 @@ ROCM_BLOCKLIST = [
"distributed/_shard/sharded_tensor/ops/test_math_ops",
"distributed/_shard/sharded_tensor/ops/test_matrix_ops",
"distributed/_shard/sharded_tensor/ops/test_softmax",
"distributed/_shard/sharded_tensor/ops/test_tensor_ops",
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_optim/test_sharded_optim",
"distributed/_shard/test_partial_tensor",
"distributed/_shard/test_replicated_tensor",

View file

@ -9,6 +9,7 @@ import torch.distributed._shard.sharding_spec as shard_spec
from torch.distributed._shard.partial_tensor import _PartialTensor
from .api import (
_CUSTOM_SHARDED_OPS,
_SHARDED_OPS,
Shard,
ShardedTensor,
@ -411,7 +412,7 @@ def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict,
if isinstance(state_dict[key], ShardedTensor):
setattr(submodule, attr_name, state_dict[key])
def sharded_op_impl(func):
def custom_sharded_op_impl(func):
"""
Provides a way for users to write their own custom sharded operator. This
can be used to override existing ShardedTensor operators or write a new
@ -420,7 +421,7 @@ def sharded_op_impl(func):
parameters, the function provided will be invoked for that operator.
Example::
>>> @sharded_op_impl(torch.nn.functional.linear)
>>> @custom_sharded_op_impl(torch.nn.functional.linear)
>>> def my_custom_sharded_linear(types, args, kwargs, process_group):
>>> ....
>>>
@ -441,6 +442,16 @@ def sharded_op_impl(func):
func(Callable): Torch function for which we want to provide a sharded
implementation (ex: torch.nn.functional.linear)
"""
return functools.partial(
_decorator_func,
op=func,
op_table=_CUSTOM_SHARDED_OPS
)
def _sharded_op_impl(func):
"""
Decorator to register a default sharded op.
"""
return functools.partial(
_decorator_func,
op=func,

View file

@ -1,6 +1,6 @@
import functools
from torch.distributed._shard.sharded_tensor import (
sharded_op_impl,
_sharded_op_impl,
Shard,
ShardedTensor,
)
@ -13,7 +13,7 @@ def _sharded_op_common(op, early_stop_func, extra_check):
Example::
>>> op = torch.transpose
>>> @sharded_op_impl(op)
>>> @_sharded_op_impl(op)
>>> @_sharded_op_common(op, early_stop_func, extra_check)
>>> def sharded_tensor_op(types, args, kwargs, process_group):
>>> ....
@ -82,7 +82,7 @@ def _register_sharded_op_on_local_shards(
func (Callable): registered implementation for sharded op for
``__torch_function__`` dispatch.
"""
@sharded_op_impl(op)
@_sharded_op_impl(op)
@_sharded_op_common(op, early_stop_func, extra_check)
def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
st = args[0]

View file

@ -3,7 +3,7 @@ import torch.distributed as dist
import torch.distributed.distributed_c10d as distributed_c10d
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
sharded_op_impl
_sharded_op_impl
)
def _communicate_result(result, pg):
@ -59,10 +59,10 @@ def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None):
return _communicate_result(True, st1._process_group)
@sharded_op_impl(torch.equal)
@_sharded_op_impl(torch.equal)
def equal(types, args, kwargs, process_group):
return binary_cmp(torch.equal, types, args, kwargs, process_group)
@sharded_op_impl(torch.allclose)
@_sharded_op_impl(torch.allclose)
def allclose(types, args, kwargs, process_group):
return binary_cmp(torch.allclose, types, args, kwargs, process_group)

View file

@ -1,13 +1,13 @@
import torch
from torch.distributed._shard.sharded_tensor import (
sharded_op_impl,
_sharded_op_impl,
ShardedTensor,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
def register_chunk_op(op):
@sharded_op_impl(op)
@_sharded_op_impl(op)
def sharded_chunk(types, args=(), kwargs=None, pg=None):
"""
Handles ``__torch_function__`` dispatch for the chunk op.

View file

@ -1,14 +1,14 @@
import torch
import torch.distributed._shard.sharded_tensor as sharded_tensor
from torch.distributed._shard.sharded_tensor import (
sharded_op_impl,
_sharded_op_impl,
)
def validate_param(param, param_name):
if param is None:
raise ValueError(f"param: {param_name} shouldn't be None!")
@sharded_op_impl(torch.nn.init.uniform_)
@_sharded_op_impl(torch.nn.init.uniform_)
def uniform_(types, args=(), kwargs=None, pg=None):
r"""
Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
@ -30,7 +30,7 @@ def uniform_(types, args=(), kwargs=None, pg=None):
torch.nn.init.uniform_(shard.tensor, a=a, b=b)
return sharded_tensor
@sharded_op_impl(torch.nn.init.normal_)
@_sharded_op_impl(torch.nn.init.normal_)
def normal_(types, args=(), kwargs=None, pg=None):
r"""
Fills the Tensors in sharded_tensor.local_shards with values drawn from the normal
@ -52,7 +52,7 @@ def normal_(types, args=(), kwargs=None, pg=None):
torch.nn.init.normal_(shard.tensor, mean=mean, std=std)
return sharded_tensor
@sharded_op_impl(torch.nn.init.kaiming_uniform_)
@_sharded_op_impl(torch.nn.init.kaiming_uniform_)
def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
r"""
Fills the Tensors in sharded_tensor.local_shards with values according to the method
@ -88,7 +88,7 @@ def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
torch.nn.init.kaiming_uniform_(shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity)
return sharded_tensor
@sharded_op_impl(torch.nn.init.constant_)
@_sharded_op_impl(torch.nn.init.constant_)
def constant_(types, args=(), kwargs=None, pg=None):
r"""
Fills the input ShardedTensor with the value \text{val}val.
@ -116,7 +116,7 @@ tensor_like_creation_op_map = {
# tensor ops that behave the same as the default tensor
def register_tensor_creation_op(op):
@sharded_op_impl(op)
@_sharded_op_impl(op)
def tensor_creation_op(types, args=(), kwargs=None, pg=None):
"""
Handles ``__torch_function__`` dispatch for tensor creation ops that

View file

@ -2,7 +2,7 @@ import torch
from torch import Tensor
from torch.distributed._shard.sharded_tensor import (
ShardedTensor,
sharded_op_impl
_sharded_op_impl
)
from torch.distributed._shard.replicated_tensor import ReplicatedTensor
from torch.distributed._shard._utils import narrow_tensor
@ -74,7 +74,7 @@ def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None):
f"kwargs: {kwargs} not supported yet for ShardedTensor!")
def register_math_op(op):
@sharded_op_impl(op)
@_sharded_op_impl(op)
def binary_math_op(types, args=(), kwargs=None, pg=None):
return binary_math_op_impl(op, types, args, kwargs, pg)

View file

@ -1,7 +1,7 @@
import copy
import torch
from torch.distributed._shard.sharded_tensor import (
sharded_op_impl,
_sharded_op_impl,
Shard,
ShardedTensor,
)
@ -10,7 +10,7 @@ from ._common import (
)
from torch.distributed._shard.common_op_utils import _register_default_op
@sharded_op_impl(torch.Tensor.__deepcopy__)
@_sharded_op_impl(torch.Tensor.__deepcopy__)
def tensor_deepcopy(types, args=(), kwargs=None, pg=None):
# NOTE: we directly implement deepcopy magic method
# instead of using the default tensor.__deepcopy__
@ -31,18 +31,18 @@ def tensor_deepcopy(types, args=(), kwargs=None, pg=None):
# Tensor properties access
_register_default_op(torch.Tensor.requires_grad.__get__, sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.shape.__get__, sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.dtype.__get__, sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.layout.__get__, sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.size, sharded_op_impl)
_register_default_op(torch.Tensor.dim, sharded_op_impl)
_register_default_op(torch.Tensor.ndim.__get__, sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.is_contiguous, sharded_op_impl)
_register_default_op(torch.Tensor.contiguous, sharded_op_impl)
_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.shape.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.dtype.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.layout.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.size, _sharded_op_impl)
_register_default_op(torch.Tensor.dim, _sharded_op_impl)
_register_default_op(torch.Tensor.ndim.__get__, _sharded_op_impl) # type: ignore[attr-defined]
_register_default_op(torch.Tensor.is_contiguous, _sharded_op_impl)
_register_default_op(torch.Tensor.contiguous, _sharded_op_impl)
# __reduce_ex__ to dispatch to get_state/set_state
_register_default_op(torch.Tensor.__reduce_ex__, sharded_op_impl)
_register_default_op(torch.Tensor.__reduce_ex__, _sharded_op_impl)
def sharded_type_as_check(*args, **kwargs):
"""
@ -153,7 +153,7 @@ _register_sharded_op_on_local_shards(
customized_func=sharded_detach,
)
@sharded_op_impl(torch.Tensor.requires_grad_)
@_sharded_op_impl(torch.Tensor.requires_grad_)
def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
self_st = args[0]
requires_grad = args[1]

View file

@ -11,8 +11,8 @@ from typing import (
cast,
)
import copy
from functools import reduce
import weakref
import math
import threading
import torch
@ -49,9 +49,12 @@ _sharded_tensor_lock = threading.Lock()
_sharded_tensor_current_id = 0
_sharded_tensor_map: Dict[int, 'weakref.ReferenceType[ShardedTensor]'] = {}
# Custom sharded ops
# Default sharded ops
_SHARDED_OPS: Dict[Callable, Callable] = {}
# Customized user ops
_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {}
def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int):
with _sharded_tensor_lock:
if sharded_tensor_id not in _sharded_tensor_map:
@ -284,7 +287,7 @@ class ShardedTensor(object):
Default: ``None``
"""
def shard_size(shard_md):
return math.prod(shard_md.shard_sizes) # type: ignore[attr-defined]
return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]
rank = dist.get_rank(self._process_group)
full_size = self.metadata().size
@ -782,6 +785,10 @@ class ShardedTensor(object):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
def dispatch(st: ShardedTensor, func: Callable):
# Dispatch to custom user provided op first if it exists.
if func in _CUSTOM_SHARDED_OPS:
return _CUSTOM_SHARDED_OPS[func](types, args, kwargs, st._process_group)
# Dispatch to custom sharding spec op if it has one.
if _has_custom_op(st._sharding_spec, func):
return _dispatch_custom_op(