Pytorch add agent api tests (#56985)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56985

Pytorch add agent api tests

Test Plan: ci/cd

Reviewed By: cbalioglu

Differential Revision: D28020485

fbshipit-source-id: e6acf095f26ce4b99cddfbf7641fb4fa885b0c86
This commit is contained in:
Aliaksandr Ivanou 2021-04-29 06:11:18 -07:00 committed by Facebook GitHub Bot
parent 3a923a555a
commit 5c8ceefe46
2 changed files with 18 additions and 17 deletions

View file

@ -24,7 +24,8 @@ from torch.distributed.elastic.agent.server.api import (
)
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure
from torch.distributed.elastic.rendezvous import RendezvousHandler, RendezvousParameters
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import run_tests
def do_nothing():
@ -149,17 +150,6 @@ def monres(state: WorkerState):
class SimpleElasticAgentTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
# start a standalone, single process etcd server to use for all tests
cls._etcd_server = EtcdServer()
cls._etcd_server.start()
@classmethod
def tearDownClass(cls):
# stop the standalone etcd server
cls._etcd_server.stop()
def _get_worker_spec(
self,
max_restarts=1,
@ -168,10 +158,15 @@ class SimpleElasticAgentTest(unittest.TestCase):
local_world_size=8,
):
run_id = str(uuid.uuid4().int)
endpoint = self._etcd_server.get_endpoint()
port = get_free_port()
endpoint = f"127.0.0.1:{port}"
rdzv_params = RendezvousParameters(
backend="etcd", endpoint=endpoint, run_id=run_id, min_nodes=1, max_nodes=1
backend="static",
endpoint=endpoint,
run_id=run_id,
min_nodes=1,
max_nodes=1,
rank=0,
)
rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)
spec = WorkerSpec(
@ -536,7 +531,7 @@ class SimpleElasticAgentTest(unittest.TestCase):
agent = TestAgent(spec)
actual_event = agent.get_agent_status_event(state=WorkerState.SUCCEEDED)
self.assertEqual("AGENT", actual_event.source)
self.assertEqual("etcd", actual_event.metadata["rdzv_backend"])
self.assertEqual("static", actual_event.metadata["rdzv_backend"])
self.assertEqual(WorkerState.SUCCEEDED.value, actual_event.metadata["state"])
self.assertEqual(spec.role, actual_event.metadata["role"])
@ -550,7 +545,11 @@ class SimpleElasticAgentTest(unittest.TestCase):
worker=agent._worker_group.workers[0],
)
self.assertEqual("WORKER", actual_event.source)
self.assertEqual("etcd", actual_event.metadata["rdzv_backend"])
self.assertEqual("static", actual_event.metadata["rdzv_backend"])
self.assertEqual(WorkerState.SUCCEEDED.value, actual_event.metadata["state"])
self.assertEqual(spec.role, actual_event.metadata["role"])
self.assertEqual(2, actual_event.metadata["agent_restarts"])
if __name__ == "__main__":
run_tests()

View file

@ -190,6 +190,7 @@ USE_PYTEST_LIST = [
'distributions/test_utils',
'test_typing',
"distributed/elastic/events/lib_test",
"distributed/elastic/agent/server/test/api_test",
]
WINDOWS_BLOCKLIST = [
@ -223,6 +224,7 @@ WINDOWS_BLOCKLIST = [
'distributed/pipeline/sync/test_transparency',
'distributed/pipeline/sync/test_worker',
'distributed/optim/test_zero_redundancy_optimizer',
"distributed/elastic/agent/server/test/api_test",
]
ROCM_BLOCKLIST = [