From dcac3c3e06556bc0e729dd1fa75f4f1e81caa356 Mon Sep 17 00:00:00 2001 From: Hyunho Yeo Date: Fri, 7 Feb 2025 23:06:35 +0000 Subject: [PATCH] [MTIA] (2/n) Implement PyTorch APIs to query/reset device peak memory usage (#146659) Summary: Public summary (shared with Github): This diff implements the correct version of the PyTorch API "max_memory_allocated". Nit: The file previously contained two unit tests with the same name (due to wrong revert); I deleted a deprecated one to revamp the correct version. Test Plan: ``` buck2 test //mtia/host_runtime/torch_mtia/tests:test_torch_mtia_api -- -r test_max_memory_allocated ``` https://www.internalfb.com/intern/testinfra/testrun/12103424065182810 Reviewed By: yuhc Differential Revision: D68988435 Pull Request resolved: https://github.com/pytorch/pytorch/pull/146659 Approved by: https://github.com/nautsimon --- torch/mtia/memory.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torch/mtia/memory.py b/torch/mtia/memory.py index 996a825890a..17d5471bec5 100644 --- a/torch/mtia/memory.py +++ b/torch/mtia/memory.py @@ -27,12 +27,13 @@ def max_memory_allocated(device: Optional[_device_t] = None) -> int: r"""Return the maximum memory allocated in bytes for a given device. Args: - device (torch.device or int, optional): selected device. Returns - statistic for the current device, given by :func:`~torch.mtia.current_device`, - if :attr:`device` is ``None`` (default). + device (torch.device, str, or int, optional) selected device. Returns + statistics for the current device, given by current_device(), + if device is None (default). """ - - return memory_stats(device=device).get("allocated_bytes.all.peak", 0) + if not is_initialized(): + return 0 + return memory_stats(device).get("dram", 0).get("peak_bytes", 0) __all__ = [