Fix _free_weak_ref error (#78575)

Fixes #74016
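
This replaces the `eval(cls.__module__)`-based lookup of `_UntypedStorage` in `_TypedStorage._free_weak_ref` with a direct module-level reference, declares a matching `_free_weak_ref` stub on `_StorageBase`, and adds a shared-memory RPC test covering the path.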

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78575
Approved by: https://github.com/ezyang
Kurt Mohler 2022-06-01 00:07:48 +00:00 committed by PyTorch MergeBot
parent 6548f8335d
commit 1705be8ff7
3 changed files with 91 additions and 1 deletion

test/distributed/rpc/test_share_memory.py (new file)

@@ -0,0 +1,85 @@
#!/usr/bin/env python3
# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed as dist

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import contextlib
import copyreg
import os

from torch import multiprocessing
import torch.multiprocessing.reductions as TorchMpReductions
import torch.distributed.rpc as rpc
from torch.distributed.rpc.internal import _InternalRPCPickler
from torch.distributed.rpc.api import _use_rpc_pickler
from torch.testing._internal.common_utils import TestCase, run_tests


@contextlib.contextmanager
def fs_sharing():
    # Switch tensor sharing to the 'file_system' strategy for the duration
    # of the block, then restore whatever strategy was active before.
    prev_strategy = multiprocessing.get_sharing_strategy()
    multiprocessing.set_sharing_strategy('file_system')
    try:
        yield
    finally:
        multiprocessing.set_sharing_strategy(prev_strategy)


class ShareMemoryRPCPickler(_InternalRPCPickler):
    """RPC pickler that routes tensors and storages through the
    shared-memory reductions in torch.multiprocessing instead of
    serializing their contents inline."""

    def __init__(self) -> None:
        super().__init__()
        # pyre-fixme[4]: Attribute must be annotated.
        self._dispatch_table = copyreg.dispatch_table.copy()

        for t in torch._storage_classes:
            self._dispatch_table[t] = TorchMpReductions.reduce_storage

        for t in torch._tensor_classes:
            self._dispatch_table[t] = TorchMpReductions.reduce_tensor
        self._dispatch_table[torch.Tensor] = TorchMpReductions.reduce_tensor
        self._dispatch_table[
            torch.nn.parameter.Parameter
        ] = TorchMpReductions.reduce_tensor


def worker_loop(a):
    # Spawned peer: join the RPC group as rank 1 and serve until shutdown.
    rpc.init_rpc('worker1', rank=1, world_size=2)
    rpc.shutdown()


def worker_fn(m):
    # Remote target; receiving the shared-memory module is the whole test.
    pass


class TestRPCPickler(TestCase):
    def test_case(self):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'

        with fs_sharing():
            r = multiprocessing.spawn(worker_loop, join=False)

            try:
                with _use_rpc_pickler(ShareMemoryRPCPickler()):
                    rpc.init_rpc('worker0', rank=0, world_size=2)

                    m = torch.nn.Linear(1, 2)
                    m.share_memory()
                    rref = rpc.remote('worker1', worker_fn, args=(m,))

                    rref.to_here()
            finally:
                rpc.shutdown()
                r.join()


if __name__ == '__main__':
    run_tests()
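
For background, the pickler above relies on copyreg-style dispatch tables: a mapping from type to a reducer that returns (reconstructor, args), which the pickler consults before its default behavior. Below is a standalone sketch of that mechanism using a hypothetical Point class rather than tensors; the test swaps in TorchMpReductions.reduce_tensor / reduce_storage as the reducers.

import copyreg
import pickle

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

def reduce_point(p):
    # A reducer returns (reconstructor, args); unpickling calls Point(*args).
    return (Point, (p.x, p.y))

# copyreg.pickle registers the reducer in copyreg.dispatch_table, the same
# table ShareMemoryRPCPickler copies and extends above.
copyreg.pickle(Point, reduce_point)
restored = pickle.loads(pickle.dumps(Point(1, 2)))
print(restored.x, restored.y)  # 1 2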

test/run_test.py

@@ -176,6 +176,7 @@ WINDOWS_BLOCKLIST = [
     "distributed/nn/jit/test_instantiator",
     "distributed/rpc/test_faulty_agent",
     "distributed/rpc/test_tensorpipe_agent",
+    "distributed/rpc/test_share_memory",
     "distributed/rpc/cuda/test_tensorpipe_agent",
     "distributed/pipeline/sync/skip/test_api",
     "distributed/pipeline/sync/skip/test_gpipe",
@@ -227,6 +228,7 @@ ROCM_BLOCKLIST = [
     "distributed/nn/jit/test_instantiator",
     "distributed/rpc/test_faulty_agent",
     "distributed/rpc/test_tensorpipe_agent",
+    "distributed/rpc/test_share_memory",
     "distributed/rpc/cuda/test_tensorpipe_agent",
     "distributed/_shard/checkpoint/test_checkpoint",
     "distributed/_shard/checkpoint/test_file_system_checkpoint"
@@ -612,6 +614,7 @@ CUSTOM_HANDLERS = {
     "distributed/test_pg_wrapper": get_run_test_with_subprocess_fn(),
     "distributed/rpc/test_faulty_agent": get_run_test_with_subprocess_fn(),
     "distributed/rpc/test_tensorpipe_agent": get_run_test_with_subprocess_fn(),
+    "distributed/rpc/test_share_memory": get_run_test_with_subprocess_fn(),
     "distributed/rpc/cuda/test_tensorpipe_agent": get_run_test_with_subprocess_fn(),
 }
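
For context, handlers of this shape isolate the test in a fresh interpreter so a crashed or hung RPC agent cannot take down the whole runner. A minimal hedged sketch of the idea (hypothetical helper, not run_test.py's actual implementation):

import subprocess
import sys

def run_in_subprocess(test_module: str) -> int:
    # e.g. test_module = "distributed/rpc/test_share_memory"
    proc = subprocess.run([sys.executable, "test/" + test_module + ".py"])
    return proc.returncode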

torch/storage.py

@@ -63,6 +63,8 @@ class _StorageBase(object):
     @classmethod
     def _new_shared_cuda(cls, *args, **kwargs) -> T: ...  # noqa: E704
     def _shared_incref(self, *args, **kwargs): ...  # noqa: E704
+    @classmethod
+    def _free_weak_ref(cls, *args, **kwargs): ...  # noqa: E704
 
     def __str__(self):
         info_str = (
@@ -649,7 +651,7 @@ class _TypedStorage:
     @classmethod
     def _free_weak_ref(cls, *args, **kwargs):
-        return eval(cls.__module__)._UntypedStorage._free_weak_ref(*args, **kwargs)
+        return _UntypedStorage._free_weak_ref(*args, **kwargs)
 
     def _weak_ref(self, *args, **kwargs):
         return self._storage._weak_ref(*args, **kwargs)
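
For context on the removed line: eval(cls.__module__) resolves the module *name* in the evaluating frame's globals, so the lookup breaks whenever that name is not bound there, e.g. for a storage subclass defined in __main__. An illustrative sketch of that failure mode (not necessarily the exact trigger in #74016):

class MyStorageLike:
    pass

print(MyStorageLike.__module__)  # '__main__' when run as a script

try:
    # sys.modules has a '__main__' entry, but eval() looks the name up in
    # the current globals, where '__main__' is usually not defined.
    eval(MyStorageLike.__module__)
except NameError as err:
    print("dynamic lookup failed:", err)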