diff --git a/test/distributed/checkpoint/test_utils.py b/test/distributed/checkpoint/test_utils.py
index 78d97f06995..b370feb3e1a 100644
--- a/test/distributed/checkpoint/test_utils.py
+++ b/test/distributed/checkpoint/test_utils.py
@@ -11,6 +11,8 @@ from torch.distributed._shard.sharded_tensor import (
     ShardMetadata,
 )
 from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
+from torch.distributed.c10d_logger import _c10d_logger
+from torch.distributed.checkpoint.logger import _dcp_logger
 from torch.distributed.checkpoint.metadata import MetadataIndex
 from torch.distributed.checkpoint.utils import find_state_dict_object
 
@@ -122,6 +124,10 @@ class TestMedatadaIndex(TestCase):
         with self.assertRaisesRegex(ValueError, "Could not find shard"):
             find_state_dict_object(state_dict, MetadataIndex("st", [1]))
 
+    def test_dcp_logger(self):
+        self.assertTrue(_c10d_logger is not _dcp_logger)
+        self.assertEqual(1, len(_c10d_logger.handlers))
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/c10d_logger.py b/torch/distributed/c10d_logger.py
index 2c92176c53e..162cb62f992 100644
--- a/torch/distributed/c10d_logger.py
+++ b/torch/distributed/c10d_logger.py
@@ -40,7 +40,7 @@ def _get_logging_handler(
     destination: str = _DEFAULT_DESTINATION,
 ) -> Tuple[logging.Handler, str]:
     log_handler = _log_handlers[destination]
-    log_handler_name = type(log_handler).__name__
+    log_handler_name = f"{type(log_handler).__name__}-{destination}"
     return (log_handler, log_handler_name)