[DCP] Fix duplicated logging messages when enabling both c10d and dcp logger (#130423)

Fixes #129951. Would you take a moment to review it? @LucasLLC

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130423
Approved by: https://github.com/Skylion007
Author: cdzhan (2024-07-11 13:43:38 +00:00), committed by PyTorch MergeBot
parent 0d66ccaf23
commit bac10cdd6f
2 changed files with 7 additions and 1 deletion
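The root cause is standard logging behavior: logging.getLogger(name) returns the same logger object for a given name, so when the handler name was derived from the handler's type alone, the c10d and dcp destinations collapsed onto one logger and each initialization attached one more handler, emitting every record twice. A minimal standard-library sketch of that failure mode (get_or_create_logger here is illustrative shorthand, not PyTorch's internal API):

import logging

def get_or_create_logger(handler: logging.Handler) -> logging.Logger:
    # Pre-fix naming: derived from the handler type only, so two different
    # destinations using the same handler type map to the same logger name.
    logger = logging.getLogger(f"c10d-{type(handler).__name__}")
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    logger.addHandler(handler)  # a second call stacks a second handler
    return logger

a = get_or_create_logger(logging.StreamHandler())  # stands in for the c10d logger
b = get_or_create_logger(logging.StreamHandler())  # stands in for the dcp logger
assert a is b                 # same derived name -> same logger object
assert len(a.handlers) == 2   # two handlers attached to that one logger
a.info("hello")               # emitted twice, once per handler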

test/distributed/checkpoint/test_utils.py

@@ -11,6 +11,8 @@ from torch.distributed._shard.sharded_tensor import (
     ShardMetadata,
 )
 from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
+from torch.distributed.c10d_logger import _c10d_logger
+from torch.distributed.checkpoint.logger import _dcp_logger
 from torch.distributed.checkpoint.metadata import MetadataIndex
 from torch.distributed.checkpoint.utils import find_state_dict_object
@@ -122,6 +124,10 @@ class TestMedatadaIndex(TestCase):
         with self.assertRaisesRegex(ValueError, "Could not find shard"):
             find_state_dict_object(state_dict, MetadataIndex("st", [1]))
 
+    def test_dcp_logger(self):
+        self.assertTrue(_c10d_logger is not _dcp_logger)
+        self.assertEqual(1, len(_c10d_logger.handlers))
+
 
 if __name__ == "__main__":
     run_tests()
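Assuming the test above lands in test/distributed/checkpoint/test_utils.py (the file defining TestMedatadaIndex), the new check can be run on its own with pytest:

python -m pytest test/distributed/checkpoint/test_utils.py -k test_dcp_logger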

torch/distributed/c10d_logger.py

@@ -40,7 +40,7 @@ def _get_logging_handler(
     destination: str = _DEFAULT_DESTINATION,
 ) -> Tuple[logging.Handler, str]:
     log_handler = _log_handlers[destination]
-    log_handler_name = type(log_handler).__name__
+    log_handler_name = f"{type(log_handler).__name__}-{destination}"
     return (log_handler, log_handler_name)
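With the destination folded into the name, same-typed handlers registered under different destinations now produce distinct handler names, hence distinct logger names downstream and exactly one handler per logger, which is what the new test asserts. A small sketch of the before/after naming, assuming both destinations hold a NullHandler and that "dcp_logger" is the dcp destination key (both are assumptions for illustration):

import logging

_log_handlers = {
    "default": logging.NullHandler(),     # c10d's default destination
    "dcp_logger": logging.NullHandler(),  # assumed key for the dcp destination
}

# Before the fix: the name depends only on the handler type.
before = {dest: type(h).__name__ for dest, h in _log_handlers.items()}
# After the fix: the destination is appended, disambiguating the names.
after = {dest: f"{type(h).__name__}-{dest}" for dest, h in _log_handlers.items()}

assert len(set(before.values())) == 1  # "NullHandler" twice -> collision
assert len(set(after.values())) == 2   # unique names -> separate loggers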