mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fix _free_weak_ref error (#78575)
Fixes #74016 Pull Request resolved: https://github.com/pytorch/pytorch/pull/78575 Approved by: https://github.com/ezyang
This commit is contained in:
parent
6548f8335d
commit
1705be8ff7
3 changed files with 91 additions and 1 deletions
85
test/distributed/rpc/test_share_memory.py
Normal file
85
test/distributed/rpc/test_share_memory.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
#!/usr/bin/env python3
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
if not dist.is_available():
|
||||
print("Distributed not available, skipping tests", file=sys.stderr)
|
||||
sys.exit(0)
|
||||
|
||||
import copyreg
|
||||
import os
|
||||
import contextlib
|
||||
|
||||
from torch import multiprocessing
|
||||
import torch.multiprocessing.reductions as TorchMpReductions
|
||||
import torch.distributed.rpc as rpc
|
||||
from torch.distributed.rpc.internal import _InternalRPCPickler
|
||||
from torch.distributed.rpc.api import _use_rpc_pickler
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
||||
@contextlib.contextmanager
def fs_sharing():
    """Run the enclosed block with the 'file_system' sharing strategy.

    Saves the current torch.multiprocessing sharing strategy, switches to
    'file_system' for the duration of the ``with`` block, and restores the
    saved strategy on exit even if the block raises.
    """
    saved = multiprocessing.get_sharing_strategy()
    multiprocessing.set_sharing_strategy('file_system')
    try:
        yield
    finally:
        # Always restore, so one test cannot leak its strategy into another.
        multiprocessing.set_sharing_strategy(saved)
||||
|
||||
class ShareMemoryRPCPickler(_InternalRPCPickler):
    """RPC pickler that reduces tensors and storages via shared memory.

    Overrides the dispatch table so every storage/tensor class is pickled
    with the torch.multiprocessing reductions (shared-memory handles)
    instead of being copied by value.
    """

    def __init__(self) -> None:
        super().__init__()
        # Start from a copy of the global copyreg table, then override the
        # torch entries below.  (The original code had a bare no-op
        # `self._dispatch_table` expression statement here; removed.)
        # pyre-fixme[4]: Attribute must be annotated.
        self._dispatch_table = copyreg.dispatch_table.copy()

        for t in torch._storage_classes:
            self._dispatch_table[t] = TorchMpReductions.reduce_storage

        for t in torch._tensor_classes:
            self._dispatch_table[t] = TorchMpReductions.reduce_tensor
        self._dispatch_table[torch.Tensor] = TorchMpReductions.reduce_tensor
        self._dispatch_table[
            torch.nn.parameter.Parameter
        ] = TorchMpReductions.reduce_tensor
|
||||
|
||||
def worker_loop(a):
    """Child-process entry point: join the RPC group as rank 1, then exit.

    `a` is the process index supplied by multiprocessing.spawn; it is
    unused here.  init_rpc blocks until all `world_size` peers (rank 0 is
    the parent test process) have joined; shutdown blocks until all
    outstanding RPC work is done.
    """
    rpc.init_rpc('worker1', rank=1, world_size=2)
    rpc.shutdown()
|
||||
|
||||
def worker_fn(m):
    """No-op RPC target; receiving ``m`` on the callee exercises unpickling."""
    return None
|
||||
|
||||
class TestRPCPickler(TestCase):
    """End-to-end check that shared-memory tensors survive an RPC round trip.

    Spawns a second process (rank 1), sends a shared-memory Linear module
    to it via rpc.remote while the ShareMemoryRPCPickler is installed, and
    waits for the result.  Exercises the storage weak-ref path fixed in
    this commit (_free_weak_ref, see gh-74016).
    """

    def setUp(self):
        super().setUp()

    def test_case(self):
        # Rendezvous config for init_rpc in both processes.
        # NOTE(review): port 29500 is hard-coded — collides if another
        # process group on the host uses the same default.
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '29500'

        # 'file_system' sharing keeps shared storages reachable via the
        # filesystem for the duration of the test.
        with fs_sharing():
            # join=False: the child blocks in init_rpc until rank 0 joins
            # below, so we must not wait for it here.
            r = multiprocessing.spawn(worker_loop, join=False)

            try:
                with _use_rpc_pickler(ShareMemoryRPCPickler()):
                    rpc.init_rpc(
                        'worker0',
                        rank=0,
                        world_size=2)
                    m = torch.nn.Linear(1, 2)
                    m.share_memory()
                    # Ship the shared-memory module to rank 1; worker_fn is
                    # a no-op — the interesting part is (un)pickling m.
                    rref = rpc.remote(
                        'worker1',
                        worker_fn,
                        args=(m,))

                    # Block until the remote call completed.
                    rref.to_here()
            finally:
                # Order matters: shut down RPC before joining the child,
                # since the child only exits after rpc.shutdown() syncs.
                rpc.shutdown()
                r.join()
|
||||
|
||||
# Standard PyTorch test-file entry point.
if __name__ == '__main__':
    run_tests()
|
||||
|
|
@ -176,6 +176,7 @@ WINDOWS_BLOCKLIST = [
|
|||
"distributed/nn/jit/test_instantiator",
|
||||
"distributed/rpc/test_faulty_agent",
|
||||
"distributed/rpc/test_tensorpipe_agent",
|
||||
"distributed/rpc/test_share_memory",
|
||||
"distributed/rpc/cuda/test_tensorpipe_agent",
|
||||
"distributed/pipeline/sync/skip/test_api",
|
||||
"distributed/pipeline/sync/skip/test_gpipe",
|
||||
|
|
@ -227,6 +228,7 @@ ROCM_BLOCKLIST = [
|
|||
"distributed/nn/jit/test_instantiator",
|
||||
"distributed/rpc/test_faulty_agent",
|
||||
"distributed/rpc/test_tensorpipe_agent",
|
||||
"distributed/rpc/test_share_memory",
|
||||
"distributed/rpc/cuda/test_tensorpipe_agent",
|
||||
"distributed/_shard/checkpoint/test_checkpoint"
|
||||
"distributed/_shard/checkpoint/test_file_system_checkpoint"
|
||||
|
|
@ -612,6 +614,7 @@ CUSTOM_HANDLERS = {
|
|||
"distributed/test_pg_wrapper": get_run_test_with_subprocess_fn(),
|
||||
"distributed/rpc/test_faulty_agent": get_run_test_with_subprocess_fn(),
|
||||
"distributed/rpc/test_tensorpipe_agent": get_run_test_with_subprocess_fn(),
|
||||
"distributed/rpc/test_share_memory": get_run_test_with_subprocess_fn(),
|
||||
"distributed/rpc/cuda/test_tensorpipe_agent": get_run_test_with_subprocess_fn(),
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -63,6 +63,8 @@ class _StorageBase(object):
|
|||
@classmethod
# Typing stub: implemented natively on the concrete storage classes.
def _new_shared_cuda(cls, *args, **kwargs) -> T: ...  # noqa: E704
# Typing stub: increments the shared-memory refcount (native impl).
def _shared_incref(self, *args, **kwargs): ...  # noqa: E704
@classmethod
# Typing stub: releases a storage weak reference (native impl); declared
# here so _TypedStorage can delegate to it (see gh-74016 fix).
def _free_weak_ref(cls, *args, **kwargs): ...  # noqa: E704
|
||||
def __str__(self):
|
||||
info_str = (
|
||||
|
|
@ -649,7 +651,7 @@ class _TypedStorage:
|
|||
|
||||
@classmethod
def _free_weak_ref(cls, *args, **kwargs):
    """Release a storage weak reference by delegating to _UntypedStorage.

    Previously this resolved the target via
    ``eval(cls.__module__)._UntypedStorage`` — an ``eval`` on a module
    name that could fail to resolve (and left an unreachable duplicate
    return in the diff view).  Call ``_UntypedStorage`` directly instead
    (fixes gh-74016).
    """
    return _UntypedStorage._free_weak_ref(*args, **kwargs)
|
||||
|
||||
def _weak_ref(self, *args, **kwargs):
    # Thin delegate: forward to the wrapped untyped storage's _weak_ref.
    return self._storage._weak_ref(*args, **kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue