Clang format dist_utils.py and rpc/__init__.py (#56853)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56853

ghstack-source-id: 127412640

Test Plan: N/A

Reviewed By: rohan-varma

Differential Revision: D27984669

fbshipit-source-id: 8e89ba0c53107622b3ca29ea296226e260b251df
This commit is contained in:
Yi Wang 2021-04-26 11:32:43 -07:00 committed by Facebook GitHub Bot
parent 6155b0d9fa
commit 7989f2ac87
2 changed files with 49 additions and 29 deletions

View file

@ -1,8 +1,8 @@
import logging
import threading
import warnings
from typing import Generator, Tuple
import torch
import torch.distributed as dist
@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
_init_counter = 0
_init_counter_lock = threading.Lock()
def is_available():
    """Report whether the C++ RPC extension is compiled into this torch build.

    Returns True when ``torch._C`` exposes the ``_rpc_init`` entry point,
    False otherwise.
    """
    native_module = torch._C
    return hasattr(native_module, "_rpc_init")
@ -22,7 +23,7 @@ if is_available() and not torch._C._rpc_init():
if is_available():
from . import api, backend_registry, functions
from torch._C._distributed_c10d import Store
from torch._C._distributed_rpc import (
_disable_jit_rref_pickle,
_enable_jit_rref_pickle,
@ -61,16 +62,18 @@ if is_available():
_UNSET_RPC_TIMEOUT,
_DEFAULT_RPC_TIMEOUT_SEC,
) # noqa: F401
from torch._C._distributed_c10d import Store
from . import api, backend_registry, functions
from .api import * # noqa: F401,F403
from .options import TensorPipeRpcBackendOptions # noqa: F401
import numbers
import torch.distributed.autograd as dist_autograd
from .backend_registry import BackendType
from .options import TensorPipeRpcBackendOptions # noqa: F401
from .server_process_global_profiler import (
_server_process_global_profile,
)
import torch.distributed.autograd as dist_autograd
import numbers
rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]
@ -111,12 +114,14 @@ if is_available():
are available.
"""
if backend is not None and not isinstance(backend, backend_registry.BackendType):
raise TypeError(
"Argument backend must be a member of BackendType"
)
if backend is not None and not isinstance(
backend, backend_registry.BackendType
):
raise TypeError("Argument backend must be a member of BackendType")
if rpc_backend_options is not None and not isinstance(rpc_backend_options, RpcBackendOptions):
if rpc_backend_options is not None and not isinstance(
rpc_backend_options, RpcBackendOptions
):
raise TypeError(
"Argument rpc_backend_options must be an instance of RpcBackendOptions"
)
@ -182,7 +187,7 @@ if is_available():
# Use a PrefixStore to distinguish multiple invocations.
with _init_counter_lock:
global _init_counter
store = dist.PrefixStore(str('rpc_prefix_{}'.format(_init_counter)), store)
store = dist.PrefixStore(str("rpc_prefix_{}".format(_init_counter)), store)
_init_counter += 1
# Initialize autograd before RPC since _init_rpc_backend guarantees all
@ -197,7 +202,6 @@ if is_available():
# Initialize RPC.
_init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)
def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options):
type_mapping = {
backend: backend_registry.BackendType,
@ -215,7 +219,6 @@ if is_available():
)
)
def _init_rpc_backend(
backend=BackendType.TENSORPIPE, # type: ignore[attr-defined]
store=None,
@ -242,7 +245,6 @@ if is_available():
api._init_rpc_states(rpc_agent)
@api._require_initialized
def _get_debug_info():
info = _rref_context_get_debug_info()

View file

@ -1,4 +1,3 @@
import re
import sys
import time
@ -24,22 +23,31 @@ def single_threaded_process_group_agent(f):
Forces ProcessGroupAgent to use only a single thread in the ThreadPool for
sending and processing requests.
"""
@wraps(f)
def wrapper(self, *args, **kwargs):
backend_type = self.rpc_backend
if backend_type == rpc.backend_registry.BackendType["PROCESS_GROUP"]:
self.rpc_backend_options = rpc.backend_registry.construct_rpc_backend_options(
self.rpc_backend,
init_method=self.init_method,
num_send_recv_threads=1,
self.rpc_backend_options = (
rpc.backend_registry.construct_rpc_backend_options(
self.rpc_backend,
init_method=self.init_method,
num_send_recv_threads=1,
)
)
return_value = f(self, *args, **kwargs)
return return_value
return wrapper
def dist_init(old_test_method=None, setup_rpc: bool = True, clean_shutdown: bool = True,
faulty_messages=None, messages_to_delay=None):
def dist_init(
old_test_method=None,
setup_rpc: bool = True,
clean_shutdown: bool = True,
faulty_messages=None,
messages_to_delay=None,
):
"""
We use this decorator for setting up and tearing down state since
MultiProcessTestCase runs each `test*` method in a separate process and
@ -73,6 +81,7 @@ def dist_init(old_test_method=None, setup_rpc: bool = True, clean_shutdown: bool
# Setting _ignore_rref_leak to make sure OwnerRRefs are properly deleted
# in tests.
import torch.distributed.rpc.api as api
api._ignore_rref_leak = False
self.worker_id = self.rank
@ -101,15 +110,16 @@ def dist_init(old_test_method=None, setup_rpc: bool = True, clean_shutdown: bool
def noop() -> None:
    """Do nothing; used as a trivial RPC target to probe worker liveness."""
    return None
def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str:
'''
"""
Loops until an RPC to the given rank fails. This is used to
indicate that the node has failed in unit tests.
Args:
rank (int): Rank of the node expected to fail
expected_error_regex (optional, str): Regex of exception message expected. Useful to ensure a specific failure
occurs, not just any.
'''
"""
while True:
try:
rpc.rpc_sync("worker{}".format(rank), noop, args=())
@ -120,7 +130,7 @@ def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str:
def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
'''
"""
The RRef protocol holds forkIds of rrefs in a map until those forks are
confirmed by the owner. The message confirming the fork may arrive after
our tests check whether this map is empty, which leads to failures and
@ -129,7 +139,7 @@ def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
loops until the map is empty, which means the messages have been received
as processed. Call this function before asserting the map returned by
_get_debug_info is empty.
'''
"""
start = time.time()
while True:
debug_info = _rref_context_get_debug_info()
@ -157,7 +167,9 @@ def get_num_owners_and_forks() -> Tuple[str, str]:
return num_owners, num_forks
def wait_until_owners_and_forks_on_rank(num_owners: int, num_forks: int, rank: int, timeout: int = 20) -> None:
def wait_until_owners_and_forks_on_rank(
num_owners: int, num_forks: int, rank: int, timeout: int = 20
) -> None:
"""
Waits until timeout for num_forks and num_owners to exist on the rank. Used
to ensure proper deletion of RRefs in tests.
@ -175,7 +187,11 @@ def wait_until_owners_and_forks_on_rank(num_owners: int, num_forks: int, rank: i
if time.time() - start > timeout:
raise ValueError(
"Timed out waiting {} sec for {} owners and {} forks on rank, had {} owners and {} forks".format(
timeout, num_owners, num_forks, num_owners_on_rank, num_forks_on_rank
timeout,
num_owners,
num_forks,
num_owners_on_rank,
num_forks_on_rank,
)
)
@ -192,9 +208,11 @@ def initialize_pg(init_method, rank: int, world_size: int) -> None:
world_size=world_size,
)
def worker_name(rank: int) -> str:
    """Return the canonical RPC worker name for the given rank.

    Args:
        rank (int): Rank of the worker.

    Returns:
        str: Name of the form ``worker<rank>``.
    """
    template = "worker{}"
    return template.format(rank)
def get_function_event(function_events, partial_event_name):
"""
Returns the first event that matches partial_event_name in the provided