[MPS] Enable Memory Leak Detection for test_mps.py (#94646)

- To check for memory leaks in `test_mps.py`, set the environment variable `PYTORCH_TEST_MPS_MEM_LEAK_CHECK=1` when running the tests (the CUDA implementation was used as a reference).
- Added support for the following new Python interfaces in the MPS module:
`torch.mps.[empty_cache(), set_per_process_memory_fraction(), current_allocated_memory(), driver_allocated_memory()]`
- Renamed `_is_mps_on_macos_13_or_newer()` to `_mps_is_on_macos_13_or_newer()` and `_is_mps_available()` to `_mps_is_available()` for naming consistency with the `_mps` prefix (see the usage sketch below).
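
A minimal sketch of how the pieces above fit together; the tensor size and the fraction value are arbitrary illustrations, not part of the change:

```python
# Run the leak check itself with:
#   PYTORCH_TEST_MPS_MEM_LEAK_CHECK=1 python test/test_mps.py
import torch

if torch.backends.mps.is_available():  # now backed by torch._C._mps_is_available()
    torch.mps.set_per_process_memory_fraction(1.0)  # cap at recommendedMaxWorkingSetSize
    x = torch.ones(1024 * 1024, device="mps")       # ~4 MB of float32
    print(torch.mps.current_allocated_memory())     # bytes held by live tensors
    print(torch.mps.driver_allocated_memory())      # total bytes allocated by the Metal driver
    del x
    torch.mps.empty_cache()                         # return cached blocks to the system
```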

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94646
Approved by: https://github.com/malfet
Ramin Azarmehr 2023-02-13 17:56:24 +00:00 committed by PyTorch MergeBot
parent ceb0f1576b
commit b57e6fdb50
9 changed files with 319 additions and 24 deletions

File: aten/src/ATen/detail/MPSHooksInterface.h

@@ -43,6 +43,22 @@ struct TORCH_API MPSHooksInterface {
virtual void deviceSynchronize() const {
AT_ERROR("Cannot synchronize MPS device without MPS backend.");
}
virtual void emptyCache() const {
AT_ERROR("Cannot execute emptyCache() without MPS backend.");
}
virtual size_t getCurrentAllocatedMemory() const {
AT_ERROR("Cannot execute getCurrentAllocatedMemory() without MPS backend.");
}
virtual size_t getDriverAllocatedMemory() const {
AT_ERROR("Cannot execute getDriverAllocatedMemory() without MPS backend.");
}
virtual void setMemoryFraction(double /*ratio*/) const {
AT_ERROR("Cannot execute setMemoryFraction() without MPS backend.");
}
};
struct TORCH_API MPSHooksArgs {};
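
The hunk above extends the hooks interface that lets backend-agnostic ATen code call into MPS: each method has an error-raising default that the real backend overrides at registration time. A rough Python sketch of the same pattern, with hypothetical names, for illustration only:

```python
class HooksInterface:
    # Defaults fail loudly when the backend is not compiled in,
    # mirroring the AT_ERROR defaults above.
    def empty_cache(self) -> None:
        raise RuntimeError("Cannot execute emptyCache() without MPS backend.")

    def get_current_allocated_memory(self) -> int:
        raise RuntimeError(
            "Cannot execute getCurrentAllocatedMemory() without MPS backend.")


class BackendHooks(HooksInterface):
    # A concrete backend overrides the defaults and forwards to its allocator.
    def __init__(self, allocator):
        self._allocator = allocator

    def empty_cache(self) -> None:
        self._allocator.empty_cache()

    def get_current_allocated_memory(self) -> int:
        return self._allocator.current_allocated()
```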

File: aten/src/ATen/mps/MPSHooks.mm

@@ -3,6 +3,7 @@
#include <ATen/mps/MPSHooks.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSGeneratorImpl.h>
#include <ATen/mps/MPSAllocatorInterface.h>
namespace at {
namespace mps {
@@ -32,6 +33,22 @@ void MPSHooks::deviceSynchronize() const {
at::mps::device_synchronize();
}
void MPSHooks::emptyCache() const {
at::mps::getIMPSAllocator()->emptyCache();
}
size_t MPSHooks::getCurrentAllocatedMemory() const {
return at::mps::getIMPSAllocator()->getCurrentAllocatedMemory();
}
size_t MPSHooks::getDriverAllocatedMemory() const {
return at::mps::getIMPSAllocator()->getDriverAllocatedMemory();
}
void MPSHooks::setMemoryFraction(double ratio) const {
at::mps::getIMPSAllocator()->setHighWatermarkRatio(ratio);
}
using at::MPSHooksRegistry;
using at::RegistererMPSHooksRegistry;

File: aten/src/ATen/mps/MPSHooks.h

@@ -17,6 +17,10 @@ struct MPSHooks : public at::MPSHooksInterface {
Allocator* getMPSDeviceAllocator() const override;
const Generator& getDefaultMPSGenerator() const override;
void deviceSynchronize() const override;
void emptyCache() const override;
size_t getCurrentAllocatedMemory() const override;
size_t getDriverAllocatedMemory() const override;
void setMemoryFraction(double ratio) const override;
};
}} // at::mps

File: docs/source/mps.rst

@@ -11,4 +11,8 @@ torch.mps
get_rng_state
set_rng_state
manual_seed
seed
seed
empty_cache
set_per_process_memory_fraction
current_allocated_memory
driver_allocated_memory

File: test/test_mps.py

@@ -11,6 +11,7 @@ import tempfile
import os
import pprint
import copy
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -61,7 +62,137 @@ if not torch.backends.mps.is_available():
TestCase = object # noqa: F811
NNTestCase = object # noqa: F811
class MPSReluTest(TestCase):
# Determine whether to enable MPS memory leak check (uses same code as CUDA).
TEST_MPS_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'
def skipMPSMemoryLeakCheckIf(condition):
def dec(fn):
if getattr(fn, '_do_mps_memory_leak_check', True):
fn._do_mps_memory_leak_check = not condition
return fn
return dec
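# Illustrative usage of the decorator above (not part of this diff): a test
# that intentionally keeps memory alive, e.g. in a global cache, could opt
# out of the leak check like so:
#   @skipMPSMemoryLeakCheckIf(True)
#   def test_with_persistent_cache(self): ...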
class MpsMemoryLeakCheck():
def __init__(self, testcase, name=None):
self.name = testcase.id() if name is None else name
self.testcase = testcase
def __enter__(self):
# Performs a gc if required (required if any memory is held)
caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
if caching_allocator_mem_allocated > 0:
gc.collect()
torch.mps.empty_cache()
# Acquires caching allocator and driver statistics before the test is run
self.caching_allocator_before = torch.mps.current_allocated_memory()
self.driver_before = torch.mps.driver_allocated_memory()
def __exit__(self, exec_type, exec_value, traceback):
# Don't check for leaks if an exception was thrown
if exec_type is not None:
return
# Compares caching allocator before/after statistics
# An increase in allocated memory is a discrepancy indicating a possible memory leak
discrepancy_detected = False
caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
if caching_allocator_mem_allocated > self.caching_allocator_before:
discrepancy_detected = True
# Short-circuits if no discrepancy detected
if not discrepancy_detected:
return
# Validates the discrepancy persists after garbage collection and
# is confirmed by the driver API
gc.collect()
torch.mps.empty_cache()
discrepancy_detected = True
# Query memory multiple times to ensure the leak was not transient
for n in range(3):
caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
driver_mem_allocated = torch.mps.driver_allocated_memory()
caching_allocator_discrepancy = False
driver_discrepancy = False
if caching_allocator_mem_allocated > self.caching_allocator_before:
caching_allocator_discrepancy = True
if driver_mem_allocated > self.driver_before:
driver_discrepancy = True
if not (caching_allocator_discrepancy or driver_discrepancy):
# Leak was false positive, exit loop
discrepancy_detected = False
break
if caching_allocator_discrepancy and not driver_discrepancy:
# Just raises a warning if the leak is not validated by the driver API
msg = ("MPS caching allocator reports a memory leak not "
"verified by the driver API in {}! "
"Caching allocator allocated memory was {} and is now reported as {}. "
"MPS driver allocated memory was {} and is now {}.").format(
self.name, self.caching_allocator_before,
caching_allocator_mem_allocated, self.driver_before, driver_mem_allocated)
warnings.warn(msg)
elif caching_allocator_discrepancy and driver_discrepancy:
# A caching allocator discrepancy validated by the driver API is a failure
msg = ("MPS driver API confirmed a leak in {}! "
"Caching allocator allocated memory was {} and is now reported as {}. "
"MPS driver allocated memory was {} and is now {}.").format(
self.name, self.caching_allocator_before, caching_allocator_mem_allocated,
self.driver_before, driver_mem_allocated)
raise RuntimeError(msg)
# Expand TestCase class with Memory Leak Detection on MPS device
class TestCaseMPS(TestCase):
_do_mps_memory_leak_check = True
def __init__(self, method_name='runTest'):
super().__init__(method_name)
test_method = getattr(self, method_name, None)
if test_method is not None:
# Wraps the tested method if we should do MPS memory check.
if TEST_MPS_MEM_LEAK_CHECK:
if self._do_mps_memory_leak_check:
self.wrap_with_mps_policy(method_name, self.assertLeaksNoMpsTensors)
def assertLeaksNoMpsTensors(self, name=None):
name = self.id() if name is None else name
return MpsMemoryLeakCheck(self, name)
def wrap_with_mps_policy(self, method_name, policy):
test_method = getattr(self, method_name)
setattr(self, method_name, super().wrap_method_with_policy(test_method, policy))
# checks for leaks even if TEST_MPS_MEM_LEAK_CHECK is 0
def wrap_with_mps_memory_check(self, method):
return super().wrap_method_with_policy(method, self.assertLeaksNoMpsTensors)
class TestMemoryLeak(TestCaseMPS):
def test_mps_memory_leak_detection(self):
l = []
@self.wrap_with_mps_memory_check
def no_leak():
pass
# Trigger an intentional memory leak
@self.wrap_with_mps_memory_check
def leak_gpu0():
# increasing to 8M elements (32 MB of float32) to force acquiring a new block and overcome blocksize differences across platforms
l.append(torch.randn(1024 * 1024 * 8, device=torch.device("mps")))
no_leak()
# check that a RuntimeError about the memory leak was raised, which
# confirms that memory leak detection worked.
with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
leak_gpu0()
class MPSReluTest(TestCaseMPS):
def _npRelu(self, np_features):
return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)
@@ -113,7 +244,7 @@ class MPSReluTest(TestCase):
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
device="mps")
class MatmulTest(TestCase):
class MatmulTest(TestCaseMPS):
def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
if expand_tensor_1_shape:
tensor1_mps = torch.randn(shape_tensor_1, device="mps").expand(expand_tensor_1_shape)
@@ -152,7 +283,7 @@ class MatmulTest(TestCase):
self._helper((10, 3, 4), (4, 5))
class MPSLeakyReluTest(TestCase):
class MPSLeakyReluTest(TestCaseMPS):
def _npLeakyRelu(self, np_features, negative_slope=0.1):
return np.maximum(np_features, negative_slope * np_features).astype(np_features.dtype)
@@ -189,7 +320,7 @@ class MPSLeakyReluTest(TestCase):
device="cpu")
class TestAvgPool(TestCase):
class TestAvgPool(TestCaseMPS):
def _sum_pool2d(self, x, kernel_size):
windows = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
return torch.sum(windows, dim=1)
@@ -239,7 +370,7 @@ class TestAvgPool(TestCase):
self.assertTrue(not torch.isnan(y).any())
class TestMPS(TestCase):
class TestMPS(TestCaseMPS):
def test_exp(self, device="mps", dtype=torch.float):
for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
b = torch.arange(18, device="cpu") / 3 * math.pi
@@ -2479,7 +2610,7 @@ class TestMPS(TestCase):
helper((2, 8, 4, 5), torch.int16)
class TestLogical(TestCase):
class TestLogical(TestCaseMPS):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -2591,7 +2722,7 @@ class TestLogical(TestCase):
[helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]]
class TestSmoothL1Loss(TestCase):
class TestSmoothL1Loss(TestCaseMPS):
def _smooth_l1_loss_helper(self, reduction="mean", requires_grad=False):
# CPU
@@ -2630,7 +2761,7 @@ class TestSmoothL1Loss(TestCase):
self._smooth_l1_loss_helper(reduction="sum", requires_grad=True)
class TestNLLLoss(TestCase):
class TestNLLLoss(TestCaseMPS):
def test_nll_loss_mismatched_batch(self, device='mps'):
x = torch.randn((10, 3), requires_grad=True, device=device)
# t should have size (10,)
@@ -6031,6 +6162,27 @@ class TestNLLLoss(TestCase):
x.backward(torch.randn_like(x))
torch.mps.synchronize()
def test_mps_allocator_module(self):
# first garbage collect and empty the cached blocks
gc.collect()
torch.mps.empty_cache()
# measure memory allocations from MPSAllocator
current_alloc_before = torch.mps.current_allocated_memory()
# after garbage collection and emptying the cache the
# current_allocated_memory must be zero
self.assertTrue(current_alloc_before == 0)
# measure total memory allocations from Metal driver
driver_alloc_before = torch.mps.driver_allocated_memory()
# allocate a new tensor of 8M float32 elements (32 MB) to force allocation of a new Metal Heap
x = torch.ones(1024 * 1024 * 8, device="mps")
# get memory allocations after allocating tensor x
current_alloc_after = torch.mps.current_allocated_memory()
driver_alloc_after = torch.mps.driver_allocated_memory()
# current and driver memory allocations must have
# grown at this point
self.assertTrue(current_alloc_after > current_alloc_before)
self.assertTrue(driver_alloc_after > driver_alloc_before)
# Test random_.to and random_.from
def test_random(self):
def helper(shape, low, high, dtype=torch.int32):
@@ -6525,7 +6677,7 @@ class TestNNMPS(NNTestCase):
# self.assertEqual(expect, actual)
class TestConstantPadNd(TestCase):
class TestConstantPadNd(TestCaseMPS):
def test_preserves_memory_format(self):
nchw_tensor = torch.rand((1, 2, 5, 3))
nchw_padded = torch.constant_pad_nd(nchw_tensor, [1, 2], 0.5)
@@ -6536,7 +6688,7 @@ class TestConstantPadNd(TestCase):
self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last))
class TestLinalgMPS(TestCase):
class TestLinalgMPS(TestCaseMPS):
def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False):
dtype = t.dtype
numpy_dtype = dtype
@@ -6602,7 +6754,7 @@ class TestLinalgMPS(TestCase):
m2 = torch.randn(25, device=device).to(dtype)
self._test_addr(torch.addr, M, m1, m2, beta=0)
class TestGatherScatter(TestCase):
class TestGatherScatter(TestCaseMPS):
def test_slicing_with_step(self):
# Slicing with step
# https://github.com/pytorch/pytorch/issues/78886
@@ -6667,7 +6819,7 @@ class TestGatherScatter(TestCase):
# They are subset of those tests as currently only this subset is working.
# This whole `class` will be removed when we add generic device testing. There
# are no additional tests added apart from what is part of test_view_ops.py
class TestViewOpsMPS(TestCase):
class TestViewOpsMPS(TestCaseMPS):
exact_dtype = True
def test_permute_slicing(self):
@@ -7478,7 +7630,7 @@ class TestViewOpsMPS(TestCase):
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
self.assertEqual(x.view(6).shape, [6])
class TestConvolutionMPS(TestCase):
class TestConvolutionMPS(TestCaseMPS):
def test_conv1d_all_strides_paddings(self):
# https://github.com/pytorch/pytorch/issues/82921
def helper(stride, padding):
@@ -7837,7 +7989,7 @@ class TestConvolutionMPS(TestCase):
msg="groundtruth comparison failed for mode={}, "
"padding_mode={}".format(mode, padding_mode))
class TestAdvancedIndexing(TestCase):
class TestAdvancedIndexing(TestCaseMPS):
supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
@@ -8641,7 +8793,7 @@ class TestAdvancedIndexing(TestCase):
out = x[idx] # index
self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
class TestRNNMPS(TestCase):
class TestRNNMPS(TestCaseMPS):
def test_lstm_1(self, device="mps", dtype=torch.float32):
rnn = nn.LSTM(1, 4, 2, device="cpu")
@@ -8851,7 +9003,7 @@ for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]
del MPS_DTYPES[MPS_DTYPES.index(t)]
class TestConsistency(TestCase):
class TestConsistency(TestCaseMPS):
# TODO: This is only used while some ops are being added.
# This list should contain all ops and dtypes eventually
# This can be generated automatically in the `new_mps_allowlist.txt` file
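
The detection flow added in this file snapshots the caching-allocator and driver counters before each test and treats post-test growth as a leak only if it survives garbage collection and is confirmed by the driver counter. A standalone sketch of that before/after pattern; the helper name is hypothetical and the threshold logic is simplified:

```python
import gc
import torch

class MemorySnapshot:
    # Simplified, hypothetical variant of the MpsMemoryLeakCheck pattern above.
    def __enter__(self):
        gc.collect()
        torch.mps.empty_cache()
        self.alloc_before = torch.mps.current_allocated_memory()
        self.driver_before = torch.mps.driver_allocated_memory()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None:
            return  # skip the check if the block already raised
        gc.collect()
        torch.mps.empty_cache()
        leaked = torch.mps.current_allocated_memory() - self.alloc_before
        # Growth must also show up in the driver counter to count as a real leak.
        if leaked > 0 and torch.mps.driver_allocated_memory() > self.driver_before:
            raise RuntimeError(f"possible MPS memory leak: {leaked} bytes still allocated")

if torch.backends.mps.is_available():
    with MemorySnapshot():
        t = torch.zeros(1024, device="mps")
        del t
```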

File: torch/_C/__init__.pyi

@@ -1201,8 +1201,12 @@ def _multiprocessing_init() -> None: ...
# Defined in torch/csrc/mps/Module.cpp
def _mps_synchronize() -> None: ...
def _mps_get_default_generator() -> Generator: ...
def _is_mps_available() -> _bool: ...
def _is_mps_on_macos_13_or_newer() -> _bool: ...
def _mps_emptyCache() -> None: ...
def _mps_setMemoryFraction(fraction: _float) -> None: ...
def _mps_currentAllocatedMemory() -> _int: ...
def _mps_driverAllocatedMemory() -> _int: ...
def _mps_is_available() -> _bool: ...
def _mps_is_on_macos_13_or_newer() -> _bool: ...
# Defined in torch/csrc/cuda/Module.cpp
def _cuda_getCurrentStream(device: _int) -> Tuple: ...

File: torch/backends/mps/__init__.py

@@ -15,13 +15,13 @@ def is_built() -> bool:
@_lru_cache()
def is_available() -> bool:
r"""Returns a bool indicating if MPS is currently available."""
return torch._C._is_mps_available()
return torch._C._mps_is_available()
@_lru_cache()
def is_macos13_or_newer() -> bool:
r"""Returns a bool indicating whether MPS is running on MacOS 13 or newer."""
return torch._C._is_mps_on_macos_13_or_newer()
return torch._C._mps_is_on_macos_13_or_newer()
# Register prims as implementation of var_mean and group_norm

File: torch/csrc/mps/Module.cpp

@@ -1,6 +1,7 @@
#include <ATen/ATen.h>
#include <c10/util/CallOnce.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/python_numbers.h>
@@ -77,14 +78,51 @@ static PyObject* MPSModule_synchronize(PyObject* _unused, PyObject* noargs) {
END_HANDLE_TH_ERRORS
}
static PyObject* MPSModule_emptyCache(PyObject* _unused, PyObject* noargs) {
HANDLE_TH_ERRORS
at::detail::getMPSHooks().emptyCache();
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* MPSModule_setMemoryFraction(
PyObject* _unused,
PyObject* args) {
HANDLE_TH_ERRORS
THPUtils_assert(
THPUtils_checkDouble(args), "invalid argument to setMemoryFraction()");
double fraction = THPUtils_unpackDouble(args);
at::detail::getMPSHooks().setMemoryFraction(fraction);
END_HANDLE_TH_ERRORS
Py_RETURN_NONE;
}
static PyObject* MPSModule_currentAllocatedMemory(
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
return PyLong_FromUnsignedLongLong(
at::detail::getMPSHooks().getCurrentAllocatedMemory());
END_HANDLE_TH_ERRORS
}
static PyObject* MPSModule_driverAllocatedMemory(
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
return PyLong_FromUnsignedLongLong(
at::detail::getMPSHooks().getDriverAllocatedMemory());
END_HANDLE_TH_ERRORS
}
// NOLINTNEXTLINE(modernize-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables,
// cppcoreguidelines-avoid-c-arrays)
static struct PyMethodDef _MPSModule_methods[] = {
{"_mps_synchronize", MPSModule_synchronize, METH_NOARGS, nullptr},
{"_mps_is_in_bad_fork", MPSModule_isInBadFork, METH_NOARGS, nullptr},
{"_is_mps_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
{"_is_mps_on_macos_13_or_newer",
{"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
{"_mps_is_on_macos_13_or_newer",
MPSModule_isMacOS13orNewer,
METH_NOARGS,
nullptr},
@@ -92,6 +130,16 @@ static struct PyMethodDef _MPSModule_methods[] = {
MPSModule_getDefaultMPSGenerator,
METH_NOARGS,
nullptr},
{"_mps_emptyCache", MPSModule_emptyCache, METH_NOARGS, nullptr},
{"_mps_setMemoryFraction", MPSModule_setMemoryFraction, METH_O, nullptr},
{"_mps_currentAllocatedMemory",
MPSModule_currentAllocatedMemory,
METH_NOARGS,
nullptr},
{"_mps_driverAllocatedMemory",
MPSModule_driverAllocatedMemory,
METH_NOARGS,
nullptr},
{nullptr}};
PyMethodDef* python_functions() {

File: torch/mps/__init__.py

@@ -50,5 +50,55 @@ def seed() -> None:
r"""Sets the seed for generating random numbers to a random number."""
_get_default_mps_generator().seed()
def empty_cache() -> None:
r"""Releases all unoccupied cached memory currently held by the caching
allocator so that it can be used by other GPU applications.
"""
torch._C._mps_emptyCache()
def set_per_process_memory_fraction(fraction) -> None:
r"""Set memory fraction for limiting process's memory allocation on MPS device.
The allowed value equals the fraction multiplied by recommended maximum device memory
(obtained from Metal API device.recommendedMaxWorkingSetSize).
If trying to allocate more than the allowed value in a process, it will raise an out of
memory error in allocator.
Args:
fraction(float): Range: 0~2. Allowed memory equals total_memory * fraction.
.. note::
Passing 0 to fraction means unlimited allocations
(may cause system failure if out of memory).
Passing a fraction greater than 1.0 raises the limit beyond the value
returned from device.recommendedMaxWorkingSetSize.
"""
if not isinstance(fraction, float):
raise TypeError('Invalid type for fraction argument, must be `float`')
if fraction < 0 or fraction > 2:
raise ValueError('Invalid fraction value: {}. Allowed range: 0~2'.format(fraction))
torch._C._mps_setMemoryFraction(fraction)
def current_allocated_memory() -> int:
r"""Returns the current GPU memory occupied by tensors in bytes.
.. note::
The returned size does not include cached allocations in
memory pools of MPSAllocator.
"""
return torch._C._mps_currentAllocatedMemory()
def driver_allocated_memory() -> int:
r"""Returns total GPU memory allocated by Metal driver for the process in bytes.
.. note::
The returned size includes cached allocations in MPSAllocator pools
as well as allocations from MPS/MPSGraph frameworks.
"""
return torch._C._mps_driverAllocatedMemory()
__all__ = [
'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize']
'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize',
'empty_cache', 'set_per_process_memory_fraction', 'current_allocated_memory',
'driver_allocated_memory']
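
For illustration, a sketch of the fraction semantics documented above; the fraction and allocation size are arbitrary values chosen so the request exceeds the configured limit on typical hardware:

```python
import torch

# The allocator's limit is fraction * device.recommendedMaxWorkingSetSize;
# 0 disables the limit and values up to 2.0 raise it past the recommendation.
if torch.backends.mps.is_available():
    torch.mps.set_per_process_memory_fraction(0.5)
    try:
        big = torch.empty(1 << 32, device="mps")  # 4G float32 elements (~16 GiB)
    except RuntimeError as err:
        print("allocation beyond the configured fraction failed:", err)
```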