From b0cfa96e82d7fbd02f5dbcef2632714caf89615d Mon Sep 17 00:00:00 2001
From: Kurman Karabukaev
Date: Sat, 2 Mar 2024 08:07:52 +0000
Subject: [PATCH] [Torchelastic][Logging] Pluggable LogsSpecs using Python
 entrypoints and an option to specify one by name. (#120942)

Summary:
Expose an option that lets users specify, by name, which LogsSpecs implementation to use.
- The implementation has to be defined as an entrypoint under the `torchrun.logs_specs` group.
- It must implement the `LogsSpecs` interface defined in the prior PR/diff.

Test Plan: unit tests + local tests

Reviewed By: ezyang

Differential Revision: D54180838

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120942
Approved by: https://github.com/ezyang
---
 setup.py                                |  5 +-
 test/distributed/launcher/run_test.py   | 54 ++++++++++++++++++-
 .../elastic/multiprocessing/__init__.py |  2 +
 .../elastic/multiprocessing/api.py      | 17 ++++--
 torch/distributed/launcher/api.py       |  8 +--
 torch/distributed/run.py                | 47 ++++++++++++++--
 6 files changed, 116 insertions(+), 17 deletions(-)

diff --git a/setup.py b/setup.py
index 7eec481813d..84cd3fd0403 100644
--- a/setup.py
+++ b/setup.py
@@ -1055,7 +1055,10 @@ def configure_extension_build():
             "convert-caffe2-to-onnx = caffe2.python.onnx.bin.conversion:caffe2_to_onnx",
             "convert-onnx-to-caffe2 = caffe2.python.onnx.bin.conversion:onnx_to_caffe2",
             "torchrun = torch.distributed.run:main",
-        ]
+        ],
+        "torchrun.logs_specs": [
+            "default = torch.distributed.elastic.multiprocessing:DefaultLogsSpecs",
+        ],
     }
 
     return extensions, cmdclass, packages, entry_points, extra_install_requires
diff --git a/test/distributed/launcher/run_test.py b/test/distributed/launcher/run_test.py
index be30e0fb258..f33d075d8a7 100644
--- a/test/distributed/launcher/run_test.py
+++ b/test/distributed/launcher/run_test.py
@@ -17,17 +17,18 @@ import unittest
 import uuid
 from contextlib import closing
 from unittest import mock
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
 
 import torch.distributed.run as launch
 from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
+from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
 from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
 from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
 from torch.distributed.elastic.utils import get_socket_with_port
 from torch.distributed.elastic.utils.distributed import get_free_port
 from torch.testing._internal.common_utils import (
-    TEST_WITH_DEV_DBG_ASAN,
     skip_but_pass_in_sandcastle_if,
+    TEST_WITH_DEV_DBG_ASAN,
 )
 
 
@@ -504,6 +505,55 @@ class ElasticLaunchTest(unittest.TestCase):
             is_torchelastic_launched = fp.readline()
         self.assertEqual("True", is_torchelastic_launched)
 
+    @patch("torch.distributed.run.metadata")
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+    def test_is_torchelastic_launched_with_logs_spec_defined(self, metadata_mock):
+        # Mock the entrypoint API to avoid version issues.
+        entrypoints = MagicMock()
+        metadata_mock.entry_points.return_value = entrypoints
+
+        group = MagicMock()
+        entrypoints.select.return_value = group
+
+        ep = MagicMock()
+        ep.load.return_value = DefaultLogsSpecs
+
+        group.select.return_value = ep
+        group.__getitem__.return_value = ep
+
+        out_file = os.path.join(self.test_dir, "out")
+        if os.path.exists(out_file):
+            os.remove(out_file)
+        launch.main(
+            [
+                "--run-path",
+                "--nnodes=1",
+                "--nproc-per-node=1",
+                "--monitor-interval=1",
+                "--logs_specs=default",
+                path("bin/test_script_is_torchelastic_launched.py"),
+                f"--out-file={out_file}",
+            ]
+        )
+
+        with open(out_file) as fp:
+            is_torchelastic_launched = fp.readline()
+        self.assertEqual("True", is_torchelastic_launched)
+
+    @skip_but_pass_in_sandcastle_if(TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan")
+    def test_logs_logs_spec_entrypoint_must_be_defined(self):
+        with self.assertRaises(ValueError):
+            launch.main(
+                [
+                    "--run-path",
+                    "--nnodes=1",
+                    "--nproc-per-node=1",
+                    "--monitor-interval=1",
+                    "--logs_specs=DOESNOT_EXIST",
+                    path("bin/test_script_is_torchelastic_launched.py"),
+                ]
+            )
+
     def test_is_not_torchelastic_launched(self):
         # launch test script without torchelastic and validate that
         # torch.distributed.is_torchelastic_launched() returns False
diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py
index db60a66ad07..d7e6a55406f 100644
--- a/torch/distributed/elastic/multiprocessing/__init__.py
+++ b/torch/distributed/elastic/multiprocessing/__init__.py
@@ -68,6 +68,7 @@ from typing import Callable, Dict, Optional, Tuple, Union, Set
 from torch.distributed.elastic.multiprocessing.api import (  # noqa: F401
     _validate_full_rank,
     DefaultLogsSpecs,
+    LogsDest,
     LogsSpecs,
     MultiprocessContext,
     PContext,
@@ -88,6 +89,7 @@ __all__ = [
     "RunProcsResult",
     "SignalException",
     "Std",
+    "LogsDest",
     "LogsSpecs",
     "DefaultLogsSpecs",
     "SubprocessContext",
diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py
index c53477ee7e9..b0bd8d117fe 100644
--- a/torch/distributed/elastic/multiprocessing/api.py
+++ b/torch/distributed/elastic/multiprocessing/api.py
@@ -184,10 +184,18 @@ class LogsDest:
 class LogsSpecs(ABC):
     """
     Defines logs processing and redirection for each worker process.
+
     Args:
-        log_dir: base directory where logs will be written
-        redirects: specifies which streams to redirect to files.
-        tee: specifies which streams to duplicate to stdout/stderr
+        log_dir:
+            Base directory where logs will be written.
+        redirects:
+            Streams to redirect to files. Pass a single ``Std``
+            enum to redirect for all workers, or a mapping keyed
+            by local_rank to selectively redirect.
+        tee:
+            Streams to duplicate to stdout/stderr.
+            Pass a single ``Std`` enum to duplicate streams for all workers,
+            or a mapping keyed by local_rank to selectively duplicate.
     """
 
     def __init__(
@@ -220,7 +228,8 @@ class LogsSpecs(ABC):
 class DefaultLogsSpecs(LogsSpecs):
     """
     Default LogsSpecs implementation:
-    - `log_dir` will be created if it doesn't exist and it is not set to os.devnull
+
+    - `log_dir` will be created if it doesn't exist and it is not set to `os.devnull`
     - Generates nested folders for each attempt and rank.
""" def __init__( diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index 214d043bcb4..f2b4aca644f 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -54,12 +54,6 @@ class LaunchConfig: as a period of monitoring workers. start_method: The method is used by the elastic agent to start the workers (spawn, fork, forkserver). - log_dir: base log directory where log files are written. If not set, - one is created in a tmp dir but NOT removed on exit. - redirects: configuration to redirect stdout/stderr to log files. - Pass a single ``Std`` enum to redirect all workers, - or a mapping keyed by local_rank to selectively redirect. - tee: configuration to "tee" stdout/stderr to console + log file. metrics_cfg: configuration to initialize metrics. local_addr: address of the local node if any. If not set, a lookup on the local machine's FQDN will be performed. @@ -248,9 +242,9 @@ def launch_agent( agent = LocalElasticAgent( spec=spec, + logs_specs=config.logs_specs, # type: ignore[arg-type] start_method=config.start_method, log_line_prefix_template=config.log_line_prefix_template, - logs_specs=config.logs_specs, # type: ignore[arg-type] ) shutdown_rdzv = True diff --git a/torch/distributed/run.py b/torch/distributed/run.py index b6f2fdd1c47..4928f6c4119 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -375,12 +375,13 @@ import logging import os import sys import uuid +import importlib.metadata as metadata from argparse import REMAINDER, ArgumentParser -from typing import Callable, List, Tuple, Union, Optional, Set +from typing import Callable, List, Tuple, Type, Union, Optional, Set import torch from torch.distributed.argparse_util import check_env, env -from torch.distributed.elastic.multiprocessing import Std, DefaultLogsSpecs +from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std from torch.distributed.elastic.multiprocessing.errors import record from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config from torch.distributed.elastic.utils import macros @@ -602,6 +603,15 @@ def get_args_parser() -> ArgumentParser: "machine's FQDN.", ) + parser.add_argument( + "--logs-specs", + "--logs_specs", + default=None, + type=str, + help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. " + "Can be used to override custom logging behavior.", + ) + # # Positional arguments. # @@ -699,6 +709,36 @@ def get_use_env(args) -> bool: return args.use_env +def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]: + """ + Attemps to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param. + Provides plugin mechanism to provide custom implementation of LogsSpecs. + + Returns `DefaultLogsSpecs` when logs_spec_name is None. + Raises ValueError when entrypoint for `logs_spec_name` can't be found in entrypoints. 
+ """ + logs_specs_cls = None + if logs_specs_name is not None: + eps = metadata.entry_points() + if hasattr(eps, "select"): # >= 3.10 + group = eps.select(group="torchrun.logs_specs") + if group.select(name=logs_specs_name): + logs_specs_cls = group[logs_specs_name].load() + + elif specs := eps.get("torchrun.logs_specs"): # < 3.10 + if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]: + logs_specs_cls = entrypoint_list[0].load() + + if logs_specs_cls is None: + raise ValueError(f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key") + + logging.info("Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls)) + else: + logs_specs_cls = DefaultLogsSpecs + + return logs_specs_cls + + def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]: # If ``args`` not passed, defaults to ``sys.argv[:1]`` min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes) @@ -745,7 +785,8 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2" ) from e - logs_specs = DefaultLogsSpecs( + logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs) + logs_specs = logs_specs_cls( log_dir=args.log_dir, redirects=Std.from_str(args.redirects), tee=Std.from_str(args.tee),