diff --git a/test/distributed/elastic/agent/server/test/api_test.py b/test/distributed/elastic/agent/server/test/api_test.py index 9eaec3ec4ed..7d1155437a8 100644 --- a/test/distributed/elastic/agent/server/test/api_test.py +++ b/test/distributed/elastic/agent/server/test/api_test.py @@ -24,7 +24,8 @@ from torch.distributed.elastic.agent.server.api import ( ) from torch.distributed.elastic.multiprocessing.errors import ProcessFailure from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters -from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer +from torch.distributed.elastic.utils.distributed import get_free_port +from torch.testing._internal.common_utils import run_tests def do_nothing(): @@ -149,17 +150,6 @@ def monres(state: WorkerState): class SimpleElasticAgentTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - # start a standalone, single process etcd server to use for all tests - cls._etcd_server = EtcdServer() - cls._etcd_server.start() - - @classmethod - def tearDownClass(cls): - # stop the standalone etcd server - cls._etcd_server.stop() - def _get_worker_spec( self, max_restarts=1, @@ -168,10 +158,15 @@ class SimpleElasticAgentTest(unittest.TestCase): local_world_size=8, ): run_id = str(uuid.uuid4().int) - endpoint = self._etcd_server.get_endpoint() - + port = get_free_port() + endpoint = f"127.0.0.1:{port}" rdzv_params = RendezvousParameters( - backend="etcd", endpoint=endpoint, run_id=run_id, min_nodes=1, max_nodes=1 + backend="static", + endpoint=endpoint, + run_id=run_id, + min_nodes=1, + max_nodes=1, + rank=0, ) rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params) spec = WorkerSpec( @@ -536,7 +531,7 @@ class SimpleElasticAgentTest(unittest.TestCase): agent = TestAgent(spec) actual_event = agent.get_agent_status_event(state=WorkerState.SUCCEEDED) self.assertEqual("AGENT", actual_event.source) - self.assertEqual("etcd", actual_event.metadata["rdzv_backend"]) + self.assertEqual("static", actual_event.metadata["rdzv_backend"]) self.assertEqual(WorkerState.SUCCEEDED.value, actual_event.metadata["state"]) self.assertEqual(spec.role, actual_event.metadata["role"]) @@ -550,7 +545,11 @@ class SimpleElasticAgentTest(unittest.TestCase): worker=agent._worker_group.workers[0], ) self.assertEqual("WORKER", actual_event.source) - self.assertEqual("etcd", actual_event.metadata["rdzv_backend"]) + self.assertEqual("static", actual_event.metadata["rdzv_backend"]) self.assertEqual(WorkerState.SUCCEEDED.value, actual_event.metadata["state"]) self.assertEqual(spec.role, actual_event.metadata["role"]) self.assertEqual(2, actual_event.metadata["agent_restarts"]) + + +if __name__ == "__main__": + run_tests() diff --git a/test/run_test.py b/test/run_test.py index 9e8f0a5f7de..cb3fbdce34f 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -190,6 +190,7 @@ USE_PYTEST_LIST = [ 'distributions/test_utils', 'test_typing', "distributed/elastic/events/lib_test", + "distributed/elastic/agent/server/test/api_test", ] WINDOWS_BLOCKLIST = [ @@ -223,6 +224,7 @@ WINDOWS_BLOCKLIST = [ 'distributed/pipeline/sync/test_transparency', 'distributed/pipeline/sync/test_worker', 'distributed/optim/test_zero_redundancy_optimizer', + "distributed/elastic/agent/server/test/api_test", ] ROCM_BLOCKLIST = [