[MTIA] (2/n) Implement PyTorch APIs to query/reset device peak memory usage (#146659)

Summary:
Public summary (shared with GitHub): this diff implements a correct version of the PyTorch API `max_memory_allocated` for MTIA devices.

Nit: the file previously contained two unit tests with the same name (due to an earlier incorrect revert); I deleted the deprecated one and kept the correct version.

Test Plan:
```
buck2 test //mtia/host_runtime/torch_mtia/tests:test_torch_mtia_api -- -r test_max_memory_allocated
```

https://www.internalfb.com/intern/testinfra/testrun/12103424065182810

Reviewed By: yuhc

Differential Revision: D68988435

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146659
Approved by: https://github.com/nautsimon
Author: Hyunho Yeo (2025-02-07 23:06:35 +00:00), committed by PyTorch MergeBot
Parent: fa34128435
Commit: dcac3c3e06


```diff
@@ -27,12 +27,13 @@ def max_memory_allocated(device: Optional[_device_t] = None) -> int:
     r"""Return the maximum memory allocated in bytes for a given device.
     Args:
-        device (torch.device or int, optional): selected device. Returns
-            statistic for the current device, given by :func:`~torch.mtia.current_device`,
-            if :attr:`device` is ``None`` (default).
+        device (torch.device, str, or int, optional) selected device. Returns
+            statistics for the current device, given by current_device(),
+            if device is None (default).
     """
-    return memory_stats(device=device).get("allocated_bytes.all.peak", 0)
+    if not is_initialized():
+        return 0
+    return memory_stats(device).get("dram", 0).get("peak_bytes", 0)
 __all__ = [
```
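The new implementation reads the peak DRAM usage out of the per-device stats dict and returns 0 when the MTIA runtime has not been initialized. A minimal standalone sketch of that lookup logic, using a stand-in `memory_stats` dict in place of the real `torch.mtia.memory_stats` (names and dict shape here mirror the diff; the `initialized` flag stands in for `is_initialized()`):

```python
def max_memory_allocated(memory_stats: dict, initialized: bool = True) -> int:
    """Sketch of the peak-memory lookup from the diff above."""
    # Before the MTIA runtime is initialized there are no stats to report.
    if not initialized:
        return 0
    # Peak DRAM usage in bytes; default to 0 when either key is absent.
    # Note: the fallback here is {} (not 0) so the chained .get is safe.
    return memory_stats.get("dram", {}).get("peak_bytes", 0)

# Stats dict shaped like the "dram" / "peak_bytes" structure in the diff.
stats = {"dram": {"peak_bytes": 4096}}
print(max_memory_allocated(stats))                      # 4096
print(max_memory_allocated(stats, initialized=False))   # 0
```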