From bac10cdd6f8ac520a461d6711cf758b69ff2d9a1 Mon Sep 17 00:00:00 2001 From: cdzhan Date: Thu, 11 Jul 2024 13:43:38 +0000 Subject: [PATCH] =?UTF-8?q?[DCP]=20Fix=20duplicated=20logging=20messages?= =?UTF-8?q?=20when=20enable=20both=20c10d=20and=20dcp=20l=E2=80=A6=20(#130?= =?UTF-8?q?423)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ogger Fixes #129951 . Would you take a moment to review it? @LucasLLC Pull Request resolved: https://github.com/pytorch/pytorch/pull/130423 Approved by: https://github.com/Skylion007 --- test/distributed/checkpoint/test_utils.py | 6 ++++++ torch/distributed/c10d_logger.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/test/distributed/checkpoint/test_utils.py b/test/distributed/checkpoint/test_utils.py index 78d97f06995..b370feb3e1a 100644 --- a/test/distributed/checkpoint/test_utils.py +++ b/test/distributed/checkpoint/test_utils.py @@ -11,6 +11,8 @@ from torch.distributed._shard.sharded_tensor import ( ShardMetadata, ) from torch.distributed._shard.sharded_tensor.metadata import TensorProperties +from torch.distributed.c10d_logger import _c10d_logger +from torch.distributed.checkpoint.logger import _dcp_logger from torch.distributed.checkpoint.metadata import MetadataIndex from torch.distributed.checkpoint.utils import find_state_dict_object @@ -122,6 +124,10 @@ class TestMedatadaIndex(TestCase): with self.assertRaisesRegex(ValueError, "Could not find shard"): find_state_dict_object(state_dict, MetadataIndex("st", [1])) + def test_dcp_logger(self): + self.assertTrue(_c10d_logger is not _dcp_logger) + self.assertEqual(1, len(_c10d_logger.handlers)) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/c10d_logger.py b/torch/distributed/c10d_logger.py index 2c92176c53e..162cb62f992 100644 --- a/torch/distributed/c10d_logger.py +++ b/torch/distributed/c10d_logger.py @@ -40,7 +40,7 @@ def _get_logging_handler( destination: str = _DEFAULT_DESTINATION, ) -> Tuple[logging.Handler, str]: log_handler = _log_handlers[destination] - log_handler_name = type(log_handler).__name__ + log_handler_name = f"{type(log_handler).__name__}-{destination}" return (log_handler, log_handler_name)