mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
PyTorch: add agent API tests (#56985)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56985

PyTorch: add agent API tests

Test Plan: ci/cd

Reviewed By: cbalioglu

Differential Revision: D28020485

fbshipit-source-id: e6acf095f26ce4b99cddfbf7641fb4fa885b0c86
This commit is contained in:
parent
3a923a555a
commit
5c8ceefe46
2 changed files with 18 additions and 17 deletions
|
|
@ -24,7 +24,8 @@ from torch.distributed.elastic.agent.server.api import (
|
|||
)
|
||||
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
|
||||
from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
|
||||
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
|
||||
from torch.distributed.elastic.utils.distributed import get_free_port
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
|
||||
def do_nothing():
|
||||
|
|
@ -149,17 +150,6 @@ def monres(state: WorkerState):
|
|||
|
||||
|
||||
class SimpleElasticAgentTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# start a standalone, single process etcd server to use for all tests
|
||||
cls._etcd_server = EtcdServer()
|
||||
cls._etcd_server.start()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
# stop the standalone etcd server
|
||||
cls._etcd_server.stop()
|
||||
|
||||
def _get_worker_spec(
|
||||
self,
|
||||
max_restarts=1,
|
||||
|
|
@ -168,10 +158,15 @@ class SimpleElasticAgentTest(unittest.TestCase):
|
|||
local_world_size=8,
|
||||
):
|
||||
run_id = str(uuid.uuid4().int)
|
||||
endpoint = self._etcd_server.get_endpoint()
|
||||
|
||||
port = get_free_port()
|
||||
endpoint = f"127.0.0.1:{port}"
|
||||
rdzv_params = RendezvousParameters(
|
||||
backend="etcd", endpoint=endpoint, run_id=run_id, min_nodes=1, max_nodes=1
|
||||
backend="static",
|
||||
endpoint=endpoint,
|
||||
run_id=run_id,
|
||||
min_nodes=1,
|
||||
max_nodes=1,
|
||||
rank=0,
|
||||
)
|
||||
rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
|
||||
spec = WorkerSpec(
|
||||
|
|
@ -536,7 +531,7 @@ class SimpleElasticAgentTest(unittest.TestCase):
|
|||
agent = TestAgent(spec)
|
||||
actual_event = agent.get_agent_status_event(state=WorkerState.SUCCEEDED)
|
||||
self.assertEqual("AGENT", actual_event.source)
|
||||
self.assertEqual("etcd", actual_event.metadata["rdzv_backend"])
|
||||
self.assertEqual("static", actual_event.metadata["rdzv_backend"])
|
||||
self.assertEqual(WorkerState.SUCCEEDED.value, actual_event.metadata["state"])
|
||||
self.assertEqual(spec.role, actual_event.metadata["role"])
|
||||
|
||||
|
|
@ -550,7 +545,11 @@ class SimpleElasticAgentTest(unittest.TestCase):
|
|||
worker=agent._worker_group.workers[0],
|
||||
)
|
||||
self.assertEqual("WORKER", actual_event.source)
|
||||
self.assertEqual("etcd", actual_event.metadata["rdzv_backend"])
|
||||
self.assertEqual("static", actual_event.metadata["rdzv_backend"])
|
||||
self.assertEqual(WorkerState.SUCCEEDED.value, actual_event.metadata["state"])
|
||||
self.assertEqual(spec.role, actual_event.metadata["role"])
|
||||
self.assertEqual(2, actual_event.metadata["agent_restarts"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -190,6 +190,7 @@ USE_PYTEST_LIST = [
|
|||
'distributions/test_utils',
|
||||
'test_typing',
|
||||
"distributed/elastic/events/lib_test",
|
||||
"distributed/elastic/agent/server/test/api_test",
|
||||
]
|
||||
|
||||
WINDOWS_BLOCKLIST = [
|
||||
|
|
@ -223,6 +224,7 @@ WINDOWS_BLOCKLIST = [
|
|||
'distributed/pipeline/sync/test_transparency',
|
||||
'distributed/pipeline/sync/test_worker',
|
||||
'distributed/optim/test_zero_redundancy_optimizer',
|
||||
"distributed/elastic/agent/server/test/api_test",
|
||||
]
|
||||
|
||||
ROCM_BLOCKLIST = [
|
||||
|
|
|
|||
Loading…
Reference in a new issue