mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Clang format dist_utils.py and rpc/__init__.py (#56853)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56853 ghstack-source-id: 127412640 Test Plan: N/A Reviewed By: rohan-varma Differential Revision: D27984669 fbshipit-source-id: 8e89ba0c53107622b3ca29ea296226e260b251df
This commit is contained in:
parent
6155b0d9fa
commit
7989f2ac87
2 changed files with 49 additions and 29 deletions
|
|
@ -1,8 +1,8 @@
|
|||
import logging
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
from typing import Generator, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
|
@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
|
|||
_init_counter = 0
|
||||
_init_counter_lock = threading.Lock()
|
||||
|
||||
|
||||
def is_available():
    """Report whether the ``torch._C`` extension exposes the RPC init hook.

    Returns:
        bool: True when the binary was built with RPC support
        (i.e. ``torch._C`` has a ``_rpc_init`` attribute).
    """
    rpc_init_hook = "_rpc_init"
    return hasattr(torch._C, rpc_init_hook)
|
||||
|
||||
|
|
@ -22,7 +23,7 @@ if is_available() and not torch._C._rpc_init():
|
|||
|
||||
|
||||
if is_available():
|
||||
from . import api, backend_registry, functions
|
||||
from torch._C._distributed_c10d import Store
|
||||
from torch._C._distributed_rpc import (
|
||||
_disable_jit_rref_pickle,
|
||||
_enable_jit_rref_pickle,
|
||||
|
|
@ -61,16 +62,18 @@ if is_available():
|
|||
_UNSET_RPC_TIMEOUT,
|
||||
_DEFAULT_RPC_TIMEOUT_SEC,
|
||||
) # noqa: F401
|
||||
from torch._C._distributed_c10d import Store
|
||||
|
||||
from . import api, backend_registry, functions
|
||||
from .api import * # noqa: F401,F403
|
||||
from .options import TensorPipeRpcBackendOptions # noqa: F401
|
||||
import numbers
|
||||
|
||||
import torch.distributed.autograd as dist_autograd
|
||||
|
||||
from .backend_registry import BackendType
|
||||
from .options import TensorPipeRpcBackendOptions # noqa: F401
|
||||
from .server_process_global_profiler import (
|
||||
_server_process_global_profile,
|
||||
)
|
||||
import torch.distributed.autograd as dist_autograd
|
||||
|
||||
import numbers
|
||||
|
||||
rendezvous_iterator: Generator[Tuple[Store, int, int], None, None]
|
||||
|
||||
|
|
@ -111,12 +114,14 @@ if is_available():
|
|||
are available.
|
||||
"""
|
||||
|
||||
if backend is not None and not isinstance(backend, backend_registry.BackendType):
|
||||
raise TypeError(
|
||||
"Argument backend must be a member of BackendType"
|
||||
)
|
||||
if backend is not None and not isinstance(
|
||||
backend, backend_registry.BackendType
|
||||
):
|
||||
raise TypeError("Argument backend must be a member of BackendType")
|
||||
|
||||
if rpc_backend_options is not None and not isinstance(rpc_backend_options, RpcBackendOptions):
|
||||
if rpc_backend_options is not None and not isinstance(
|
||||
rpc_backend_options, RpcBackendOptions
|
||||
):
|
||||
raise TypeError(
|
||||
"Argument rpc_backend_options must be an instance of RpcBackendOptions"
|
||||
)
|
||||
|
|
@ -182,7 +187,7 @@ if is_available():
|
|||
# Use a PrefixStore to distinguish multiple invocations.
|
||||
with _init_counter_lock:
|
||||
global _init_counter
|
||||
store = dist.PrefixStore(str('rpc_prefix_{}'.format(_init_counter)), store)
|
||||
store = dist.PrefixStore(str("rpc_prefix_{}".format(_init_counter)), store)
|
||||
_init_counter += 1
|
||||
|
||||
# Initialize autograd before RPC since _init_rpc_backend guarantees all
|
||||
|
|
@ -197,7 +202,6 @@ if is_available():
|
|||
# Initialize RPC.
|
||||
_init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options)
|
||||
|
||||
|
||||
def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options):
|
||||
type_mapping = {
|
||||
backend: backend_registry.BackendType,
|
||||
|
|
@ -215,7 +219,6 @@ if is_available():
|
|||
)
|
||||
)
|
||||
|
||||
|
||||
def _init_rpc_backend(
|
||||
backend=BackendType.TENSORPIPE, # type: ignore[attr-defined]
|
||||
store=None,
|
||||
|
|
@ -242,7 +245,6 @@ if is_available():
|
|||
|
||||
api._init_rpc_states(rpc_agent)
|
||||
|
||||
|
||||
@api._require_initialized
|
||||
def _get_debug_info():
|
||||
info = _rref_context_get_debug_info()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
|
|
@ -24,22 +23,31 @@ def single_threaded_process_group_agent(f):
|
|||
Forces ProcessGroupAgent to use only a single thread in the ThreadPool for
|
||||
sending and processing requests.
|
||||
"""
|
||||
|
||||
@wraps(f)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
backend_type = self.rpc_backend
|
||||
if backend_type == rpc.backend_registry.BackendType["PROCESS_GROUP"]:
|
||||
self.rpc_backend_options = rpc.backend_registry.construct_rpc_backend_options(
|
||||
self.rpc_backend,
|
||||
init_method=self.init_method,
|
||||
num_send_recv_threads=1,
|
||||
self.rpc_backend_options = (
|
||||
rpc.backend_registry.construct_rpc_backend_options(
|
||||
self.rpc_backend,
|
||||
init_method=self.init_method,
|
||||
num_send_recv_threads=1,
|
||||
)
|
||||
)
|
||||
return_value = f(self, *args, **kwargs)
|
||||
return return_value
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def dist_init(old_test_method=None, setup_rpc: bool = True, clean_shutdown: bool = True,
|
||||
faulty_messages=None, messages_to_delay=None):
|
||||
def dist_init(
|
||||
old_test_method=None,
|
||||
setup_rpc: bool = True,
|
||||
clean_shutdown: bool = True,
|
||||
faulty_messages=None,
|
||||
messages_to_delay=None,
|
||||
):
|
||||
"""
|
||||
We use this decorator for setting up and tearing down state since
|
||||
MultiProcessTestCase runs each `test*` method in a separate process and
|
||||
|
|
@ -73,6 +81,7 @@ def dist_init(old_test_method=None, setup_rpc: bool = True, clean_shutdown: bool
|
|||
# Setting _ignore_rref_leak to make sure OwnerRRefs are properly deleted
|
||||
# in tests.
|
||||
import torch.distributed.rpc.api as api
|
||||
|
||||
api._ignore_rref_leak = False
|
||||
|
||||
self.worker_id = self.rank
|
||||
|
|
@ -101,15 +110,16 @@ def dist_init(old_test_method=None, setup_rpc: bool = True, clean_shutdown: bool
|
|||
def noop() -> None:
    """Do nothing and return ``None`` (trivial RPC target used in tests)."""
    return None
|
||||
|
||||
|
||||
def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str:
|
||||
'''
|
||||
"""
|
||||
Loops until an RPC to the given rank fails. This is used to
|
||||
indicate that the node has failed in unit tests.
|
||||
Args:
|
||||
rank (int): Rank of the node expected to fail
|
||||
expected_error_regex (optional, str): Regex of exception message expected. Useful to ensure a specific failure
|
||||
occurs, not just any.
|
||||
'''
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
rpc.rpc_sync("worker{}".format(rank), noop, args=())
|
||||
|
|
@ -120,7 +130,7 @@ def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str:
|
|||
|
||||
|
||||
def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
|
||||
'''
|
||||
"""
|
||||
The RRef protocol holds forkIds of rrefs in a map until those forks are
|
||||
confirmed by the owner. The message confirming the fork may arrive after
|
||||
our tests check whether this map is empty, which leads to failures and
|
||||
|
|
@ -129,7 +139,7 @@ def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
|
|||
loops until the map is empty, which means the messages have been received
|
||||
as processed. Call this function before asserting the map returned by
|
||||
_get_debug_info is empty.
|
||||
'''
|
||||
"""
|
||||
start = time.time()
|
||||
while True:
|
||||
debug_info = _rref_context_get_debug_info()
|
||||
|
|
@ -157,7 +167,9 @@ def get_num_owners_and_forks() -> Tuple[str, str]:
|
|||
return num_owners, num_forks
|
||||
|
||||
|
||||
def wait_until_owners_and_forks_on_rank(num_owners: int, num_forks: int, rank: int, timeout: int = 20) -> None:
|
||||
def wait_until_owners_and_forks_on_rank(
|
||||
num_owners: int, num_forks: int, rank: int, timeout: int = 20
|
||||
) -> None:
|
||||
"""
|
||||
Waits until timeout for num_forks and num_owners to exist on the rank. Used
|
||||
to ensure proper deletion of RRefs in tests.
|
||||
|
|
@ -175,7 +187,11 @@ def wait_until_owners_and_forks_on_rank(num_owners: int, num_forks: int, rank: i
|
|||
if time.time() - start > timeout:
|
||||
raise ValueError(
|
||||
"Timed out waiting {} sec for {} owners and {} forks on rank, had {} owners and {} forks".format(
|
||||
timeout, num_owners, num_forks, num_owners_on_rank, num_forks_on_rank
|
||||
timeout,
|
||||
num_owners,
|
||||
num_forks,
|
||||
num_owners_on_rank,
|
||||
num_forks_on_rank,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -192,9 +208,11 @@ def initialize_pg(init_method, rank: int, world_size: int) -> None:
|
|||
world_size=world_size,
|
||||
)
|
||||
|
||||
|
||||
def worker_name(rank: int) -> str:
    """Return the canonical RPC worker name for the given *rank*.

    Args:
        rank (int): Rank of the worker.

    Returns:
        str: The name ``"worker<rank>"``.
    """
    return f"worker{rank}"
|
||||
|
||||
|
||||
def get_function_event(function_events, partial_event_name):
|
||||
"""
|
||||
Returns the first event that matches partial_event_name in the provided
|
||||
|
|
|
|||
Loading…
Reference in a new issue