mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Summary: Action following discussion with distributed and r2p team--the tests under elastic in distributed should be owned by oncall: r2p and not distributed. cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang Pull Request resolved: https://github.com/pytorch/pytorch/pull/67293 Reviewed By: jbschlosser Differential Revision: D31973779 Pulled By: janeyx99 fbshipit-source-id: 05875a7600c6eb1da1310a48e1e32a1a69461c55
76 lines
2.4 KiB
Python
76 lines
2.4 KiB
Python
# Owner(s): ["oncall: r2p"]
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
import unittest
|
|
import unittest.mock as mock
|
|
|
|
from torch.distributed.elastic.timer import TimerServer
|
|
from torch.distributed.elastic.timer.api import RequestQueue, TimerRequest
|
|
|
|
|
|
class MockRequestQueue(RequestQueue):
    """
    Canned ``RequestQueue`` for tests.

    Always reports two pending requests and always hands back the same two
    ``TimerRequest`` objects, regardless of the arguments passed to ``get``.
    """

    def size(self):
        """Return the fixed number of queued requests (always 2)."""
        return 2

    def get(self, size, timeout):
        """
        Return the fixed pair of timer requests.

        ``size`` and ``timeout`` are accepted to satisfy the ``RequestQueue``
        interface but are ignored by this mock.
        """
        return [TimerRequest(1, "test_1", 0), TimerRequest(2, "test_2", 0)]
|
|
|
|
|
|
class MockTimerServer(TimerServer):
    """
    Mock implementation of TimerServer for testing purposes.

    This mock has the following behavior:

    1. reaping worker 1 throws (raises ``RuntimeError``)
    2. reaping worker 2 succeeds (returns ``True``)
    3. reaping worker 3 fails (returns ``False``)

    For each of workers 1 - 3, ``get_expired_timers`` returns 2 expired timers.
    """

    def __init__(self, request_queue, max_interval):
        super().__init__(request_queue, max_interval)

    def register_timers(self, timer_requests):
        """No-op; the test patches this method to record its calls."""
        pass

    def clear_timers(self, worker_ids):
        """No-op; the test patches this method to record its calls."""
        pass

    def get_expired_timers(self, deadline):
        """
        Return two expired ``TimerRequest`` objects for each worker id 1..3.

        ``deadline`` is ignored; the result is the same on every call.
        """
        return {
            worker_id: [
                TimerRequest(worker_id, f"test_{worker_id}_0", 0),
                TimerRequest(worker_id, f"test_{worker_id}_1", 0),
            ]
            for worker_id in range(1, 4)
        }

    def _reap_worker(self, worker_id):
        """
        Simulate reaping a worker.

        Raises ``RuntimeError("test error")`` for worker 1, returns ``True``
        for worker 2 and ``False`` for every other worker id (including 3).
        """
        if worker_id == 1:
            raise RuntimeError("test error")
        # Only worker 2 is a successful reap; always return an explicit bool
        # rather than falling through to an implicit None for unknown ids.
        return worker_id == 2
|
|
|
|
|
|
class TimerApiTest(unittest.TestCase):
    """Tests for ``TimerServer._run_watchdog`` driven through mock collaborators."""

    @mock.patch.object(MockTimerServer, "register_timers")
    @mock.patch.object(MockTimerServer, "clear_timers")
    def test_run_watchdog(self, mock_clear_timers, mock_register_timers):
        """
        tests that when a ``_reap_worker()`` method throws an exception
        for a particular worker_id, the timers for successfully reaped workers
        are cleared properly
        """
        max_interval = 1
        # Keep a handle on the unwrapped queue so expected values can be
        # computed without adding calls to the mock under assertion.
        plain_queue = MockRequestQueue()
        request_queue = mock.Mock(wraps=plain_queue)
        timer_server = MockTimerServer(request_queue, max_interval)
        timer_server._run_watchdog()

        expected_batch_size = plain_queue.size()
        expected_requests = plain_queue.get(expected_batch_size, max_interval)

        request_queue.size.assert_called_once()
        request_queue.get.assert_called_with(expected_batch_size, max_interval)
        mock_register_timers.assert_called_with(expected_requests)
        # Worker 1's reap raises and worker 2's reap succeeds; both are
        # expected in the cleared set, while worker 3 (reap returned False)
        # is not.
        mock_clear_timers.assert_called_with({1, 2})
|