[MPS] Enable Memory Leak Detection for test_mps.py (#94646)
- To check for memory leaks in `test_mps.py`, set the env variable `PYTORCH_TEST_MPS_MEM_LEAK_CHECK=1` when running test_mps.py (used the CUDA code as reference).
- Added support for the following new Python interfaces in the MPS module: `torch.mps.[empty_cache(), set_per_process_memory_fraction(), current_allocated_memory(), driver_allocated_memory()]`
- Renamed `_is_mps_on_macos_13_or_newer()` to `_mps_is_on_macos_13_or_newer()`, and `_is_mps_available()` to `_mps_is_available()`, for naming consistency with the `_mps` prefix.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94646
Approved by: https://github.com/malfet
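For context, a minimal usage sketch of the new interfaces (assumes an MPS-capable Mac; the 8 MB size is arbitrary). The leak check itself is enabled by running the suite as PYTORCH_TEST_MPS_MEM_LEAK_CHECK=1 python test/test_mps.py.

import torch

driver_before = torch.mps.driver_allocated_memory()  # total bytes held by the Metal driver
x = torch.ones(1024 * 1024 * 8, device="mps")        # forces allocation of a new Metal heap
print(torch.mps.current_allocated_memory())          # bytes currently occupied by tensors
del x
torch.mps.empty_cache()                              # release cached blocks
torch.mps.set_per_process_memory_fraction(1.0)       # cap at recommendedMaxWorkingSetSize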
This commit is contained in:
parent ceb0f1576b
commit b57e6fdb50
9 changed files with 319 additions and 24 deletions
@@ -43,6 +43,22 @@ struct TORCH_API MPSHooksInterface {
   virtual void deviceSynchronize() const {
     AT_ERROR("Cannot synchronize MPS device without MPS backend.");
   }
+
+  virtual void emptyCache() const {
+    AT_ERROR("Cannot execute emptyCache() without MPS backend.");
+  }
+
+  virtual size_t getCurrentAllocatedMemory() const {
+    AT_ERROR("Cannot execute getCurrentAllocatedMemory() without MPS backend.");
+  }
+
+  virtual size_t getDriverAllocatedMemory() const {
+    AT_ERROR("Cannot execute getDriverAllocatedMemory() without MPS backend.");
+  }
+
+  virtual void setMemoryFraction(double /*ratio*/) const {
+    AT_ERROR("Cannot execute setMemoryFraction() without MPS backend.");
+  }
 };

 struct TORCH_API MPSHooksArgs {};
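These base-class stubs are what fire when PyTorch is built without the MPS backend: AT_ERROR surfaces in Python as a RuntimeError. A sketch of the observable behavior on a non-MPS build (illustrative, not part of this diff):

import torch

try:
    torch.mps.empty_cache()
except RuntimeError as err:
    # roughly: "Cannot execute emptyCache() without MPS backend."
    print(err)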
@@ -3,6 +3,7 @@
 #include <ATen/mps/MPSHooks.h>
 #include <ATen/mps/MPSDevice.h>
 #include <ATen/mps/MPSGeneratorImpl.h>
+#include <ATen/mps/MPSAllocatorInterface.h>

 namespace at {
 namespace mps {
@@ -32,6 +33,22 @@ void MPSHooks::deviceSynchronize() const {
   at::mps::device_synchronize();
 }

+void MPSHooks::emptyCache() const {
+  at::mps::getIMPSAllocator()->emptyCache();
+}
+
+size_t MPSHooks::getCurrentAllocatedMemory() const {
+  return at::mps::getIMPSAllocator()->getCurrentAllocatedMemory();
+}
+
+size_t MPSHooks::getDriverAllocatedMemory() const {
+  return at::mps::getIMPSAllocator()->getDriverAllocatedMemory();
+}
+
+void MPSHooks::setMemoryFraction(double ratio) const {
+  at::mps::getIMPSAllocator()->setHighWatermarkRatio(ratio);
+}
+
 using at::MPSHooksRegistry;
 using at::RegistererMPSHooksRegistry;
@@ -17,6 +17,10 @@ struct MPSHooks : public at::MPSHooksInterface {
   Allocator* getMPSDeviceAllocator() const override;
   const Generator& getDefaultMPSGenerator() const override;
   void deviceSynchronize() const override;
+  void emptyCache() const override;
+  size_t getCurrentAllocatedMemory() const override;
+  size_t getDriverAllocatedMemory() const override;
+  void setMemoryFraction(double ratio) const override;
 };

 }} // at::mps
@@ -11,4 +11,8 @@ torch.mps
     get_rng_state
     set_rng_state
     manual_seed
     seed
+    empty_cache
+    set_per_process_memory_fraction
+    current_allocated_memory
+    driver_allocated_memory
test/test_mps.py (184 changed lines)
@@ -11,6 +11,7 @@ import tempfile
 import os
 import pprint
 import copy
+import gc
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -61,7 +62,137 @@ if not torch.backends.mps.is_available():
     TestCase = object  # noqa: F811
     NNTestCase = object  # noqa: F811

-class MPSReluTest(TestCase):
+# Determine whether to enable MPS memory leak check (uses same code as CUDA).
+TEST_MPS_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'
+
+def skipMPSMemoryLeakCheckIf(condition):
+    def dec(fn):
+        if getattr(fn, '_do_mps_memory_leak_check', True):
+            fn._do_mps_memory_leak_check = not condition
+        return fn
+    return dec
+
+class MpsMemoryLeakCheck():
+    def __init__(self, testcase, name=None):
+        self.name = testcase.id() if name is None else name
+        self.testcase = testcase
+
+    def __enter__(self):
+        # Performs a gc if required (required if any memory is held)
+        caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
+        if caching_allocator_mem_allocated > 0:
+            gc.collect()
+            torch.mps.empty_cache()
+
+        # Acquires caching allocator and driver statistics before the test is run
+        self.caching_allocator_before = torch.mps.current_allocated_memory()
+        self.driver_before = torch.mps.driver_allocated_memory()
+
+    def __exit__(self, exec_type, exec_value, traceback):
+        # Don't check for leaks if an exception was thrown
+        if exec_type is not None:
+            return
+
+        # Compares caching allocator before/after statistics
+        # An increase in allocated memory is a discrepancy indicating a possible memory leak
+        discrepancy_detected = False
+        caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
+        if caching_allocator_mem_allocated > self.caching_allocator_before:
+            discrepancy_detected = True
+
+        # Short-circuits if no discrepancy detected
+        if not discrepancy_detected:
+            return
+
+        # Validates the discrepancy persists after garbage collection and
+        # is confirmed by the driver API
+        gc.collect()
+        torch.mps.empty_cache()
+
+        discrepancy_detected = True
+        # Query memory multiple times to ensure leak was not transient
+        for n in range(3):
+            caching_allocator_mem_allocated = torch.mps.current_allocated_memory()
+            driver_mem_allocated = torch.mps.driver_allocated_memory()
+
+            caching_allocator_discrepancy = False
+            driver_discrepancy = False
+
+            if caching_allocator_mem_allocated > self.caching_allocator_before:
+                caching_allocator_discrepancy = True
+
+            if driver_mem_allocated > self.driver_before:
+                driver_discrepancy = True
+
+            if not(caching_allocator_discrepancy or driver_discrepancy):
+                # Leak was false positive, exit loop
+                discrepancy_detected = False
+                break
+
+        if caching_allocator_discrepancy and not driver_discrepancy:
+            # Just raises a warning if the leak is not validated by the driver API
+            msg = ("MPS caching allocator reports a memory leak not "
+                   "verified by the driver API in {}! "
+                   "Caching allocator allocated memory was {} and is now reported as {}. "
+                   "MPS driver allocated memory was {} and is now {}.").format(
+                self.name, self.caching_allocator_before,
+                caching_allocator_mem_allocated, self.driver_before, driver_mem_allocated)
+            warnings.warn(msg)
+        elif caching_allocator_discrepancy and driver_discrepancy:
+            # A caching allocator discrepancy validated by the driver API is a failure
+            msg = ("MPS driver API confirmed a leak in {}! "
+                   "Caching allocator allocated memory was {} and is now reported as {}. "
+                   "MPS driver allocated memory was {} and is now {}.").format(
+                self.name, self.caching_allocator_before, caching_allocator_mem_allocated,
+                self.driver_before, driver_mem_allocated)
+
+            raise RuntimeError(msg)
+
+# Expand TestCase class with Memory Leak Detection on MPS device
+class TestCaseMPS(TestCase):
+    _do_mps_memory_leak_check = True
+
+    def __init__(self, method_name='runTest'):
+        super().__init__(method_name)
+        test_method = getattr(self, method_name, None)
+        if test_method is not None:
+            # Wraps the tested method if we should do MPS memory check.
+            if TEST_MPS_MEM_LEAK_CHECK:
+                if self._do_mps_memory_leak_check:
+                    self.wrap_with_mps_policy(method_name, self.assertLeaksNoMpsTensors)
+
+    def assertLeaksNoMpsTensors(self, name=None):
+        name = self.id() if name is None else name
+        return MpsMemoryLeakCheck(self, name)
+
+    def wrap_with_mps_policy(self, method_name, policy):
+        test_method = getattr(self, method_name)
+        setattr(self, method_name, super().wrap_method_with_policy(test_method, policy))
+
+    # checks for leaks even if TEST_MPS_MEM_LEAK_CHECK is 0
+    def wrap_with_mps_memory_check(self, method):
+        return super().wrap_method_with_policy(method, self.assertLeaksNoMpsTensors)
+
+class TestMemoryLeak(TestCaseMPS):
+    def test_mps_memory_leak_detection(self):
+        l = []
+
+        @self.wrap_with_mps_memory_check
+        def no_leak():
+            pass
+
+        # Trigger an intentional memory leak
+        @self.wrap_with_mps_memory_check
+        def leak_gpu0():
+            # increasing to 8MB to force acquiring a new block and overcome blocksize differences across platforms
+            l.append(torch.randn(1024 * 1024 * 8, device=torch.device("mps")))
+
+        no_leak()
+
+        # check if a runtime error for memory leak was emitted which would
+        # confirm whether memory leak detection worked successfully or not.
+        with self.assertRaisesRegex(RuntimeError, r"MPS driver API confirmed .+"):
+            leak_gpu0()
+
+class MPSReluTest(TestCaseMPS):
     def _npRelu(self, np_features):
         return np.maximum(np_features, np.zeros(np_features.shape)).astype(np_features.dtype)
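Beyond the automatic wrapping in TestCaseMPS above, the context manager can also be used directly inside a test; a hypothetical sketch (not part of this diff):

class TestExample(TestCaseMPS):
    def test_no_leak(self):
        # checks allocator statistics on exit, even with PYTORCH_TEST_MPS_MEM_LEAK_CHECK=0
        with MpsMemoryLeakCheck(self, name="test_no_leak"):
            y = torch.randn(16, device="mps") * 2
            del y  # freed before exit, so no discrepancy is reported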
@@ -113,7 +244,7 @@ class MPSReluTest(TestCase):
             np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
             device="mps")

-class MatmulTest(TestCase):
+class MatmulTest(TestCaseMPS):
     def _helper(self, shape_tensor_1, shape_tensor_2, expand_tensor_1_shape=None, expand_tensor_2_shape=None):
         if expand_tensor_1_shape:
             tensor1_mps = torch.randn(shape_tensor_1, device="mps").expand(expand_tensor_1_shape)
@@ -152,7 +283,7 @@ class MatmulTest(TestCase):
         self._helper((10, 3, 4), (4, 5))


-class MPSLeakyReluTest(TestCase):
+class MPSLeakyReluTest(TestCaseMPS):
     def _npLeakyRelu(self, np_features, negative_slope=0.1):
         return np.maximum(np_features, negative_slope * np_features).astype(np_features.dtype)
@@ -189,7 +320,7 @@ class MPSLeakyReluTest(TestCase):
             device="cpu")


-class TestAvgPool(TestCase):
+class TestAvgPool(TestCaseMPS):
     def _sum_pool2d(self, x, kernel_size):
         windows = torch.nn.functional.unfold(x, kernel_size=kernel_size, stride=kernel_size)
         return torch.sum(windows, dim=1)
@@ -239,7 +370,7 @@ class TestAvgPool(TestCase):
         self.assertTrue(not torch.isnan(y).any())


-class TestMPS(TestCase):
+class TestMPS(TestCaseMPS):
     def test_exp(self, device="mps", dtype=torch.float):
         for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
             b = torch.arange(18, device="cpu") / 3 * math.pi
@@ -2479,7 +2610,7 @@ class TestMPS(TestCase):

         helper((2, 8, 4, 5), torch.int16)

-class TestLogical(TestCase):
+class TestLogical(TestCaseMPS):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
         return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -2591,7 +2722,7 @@ class TestLogical(TestCase):

         [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]]

-class TestSmoothL1Loss(TestCase):
+class TestSmoothL1Loss(TestCaseMPS):

     def _smooth_l1_loss_helper(self, reduction="mean", requires_grad=False):
         # CPU
@@ -2630,7 +2761,7 @@ class TestSmoothL1Loss(TestCase):
         self._smooth_l1_loss_helper(reduction="sum", requires_grad=True)


-class TestNLLLoss(TestCase):
+class TestNLLLoss(TestCaseMPS):
     def test_nll_loss_mismatched_batch(self, device='mps'):
         x = torch.randn((10, 3), requires_grad=True, device=device)
         # t should have size (10,)
@@ -6031,6 +6162,27 @@ class TestNLLLoss(TestCase):
         x.backward(torch.randn_like(x))
         torch.mps.synchronize()

+    def test_mps_allocator_module(self):
+        # first garbage collect and empty the cached blocks
+        gc.collect()
+        torch.mps.empty_cache()
+        # measure memory allocations from MPSAllocator
+        current_alloc_before = torch.mps.current_allocated_memory()
+        # after garbage collection and emptying the cache the
+        # current_allocated_memory must be zero
+        self.assertTrue(current_alloc_before == 0)
+        # measure total memory allocations from Metal driver
+        driver_alloc_before = torch.mps.driver_allocated_memory()
+        # allocate a new 8 MB tensor to force allocation of a new Metal Heap
+        x = torch.ones(1024 * 1024 * 8, device="mps")
+        # get memory allocations after allocating tensor x
+        current_alloc_after = torch.mps.current_allocated_memory()
+        driver_alloc_after = torch.mps.driver_allocated_memory()
+        # current and driver memory allocations must have
+        # grown at this point
+        self.assertTrue(current_alloc_after > current_alloc_before)
+        self.assertTrue(driver_alloc_after > driver_alloc_before)
+
     # Test random_.to and random_.from
     def test_random(self):
         def helper(shape, low, high, dtype=torch.int32):
@@ -6525,7 +6677,7 @@ class TestNNMPS(NNTestCase):
         # self.assertEqual(expect, actual)


-class TestConstantPadNd(TestCase):
+class TestConstantPadNd(TestCaseMPS):
     def test_preserves_memory_format(self):
         nchw_tensor = torch.rand((1, 2, 5, 3))
         nchw_padded = torch.constant_pad_nd(nchw_tensor, [1, 2], 0.5)
@@ -6536,7 +6688,7 @@ class TestConstantPadNd(TestCase):
         self.assertTrue(nhwc_padded.is_contiguous(memory_format=torch.channels_last))


-class TestLinalgMPS(TestCase):
+class TestLinalgMPS(TestCaseMPS):
     def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out=False):
         dtype = t.dtype
         numpy_dtype = dtype
@@ -6602,7 +6754,7 @@ class TestLinalgMPS(TestCase):
         m2 = torch.randn(25, device=device).to(dtype)
         self._test_addr(torch.addr, M, m1, m2, beta=0)

-class TestGatherScatter(TestCase):
+class TestGatherScatter(TestCaseMPS):
     def test_slicing_with_step(self):
         # Slicing with step
         # https://github.com/pytorch/pytorch/issues/78886
@@ -6667,7 +6819,7 @@ class TestGatherScatter(TestCase):
 # They are subset of those tests as currently only this subset is working.
 # This whole `class` will be removed when we add generic device testing. There
 # are no additional tests added apart from what is part of test_view_ops.py
-class TestViewOpsMPS(TestCase):
+class TestViewOpsMPS(TestCaseMPS):
     exact_dtype = True

     def test_permute_slicing(self):
@@ -7478,7 +7630,7 @@ class TestViewOpsMPS(TestCase):
             x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
             self.assertEqual(x.view(6).shape, [6])

-class TestConvolutionMPS(TestCase):
+class TestConvolutionMPS(TestCaseMPS):
     def test_conv1d_all_strides_paddings(self):
         # https://github.com/pytorch/pytorch/issues/82921
         def helper(stride, padding):
@@ -7837,7 +7989,7 @@ class TestConvolutionMPS(TestCase):
                 msg="groundtruth comparison failed for mode={}, "
                     "padding_mode={}".format(mode, padding_mode))

-class TestAdvancedIndexing(TestCase):
+class TestAdvancedIndexing(TestCaseMPS):
     supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
     supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
@@ -8641,7 +8793,7 @@ class TestAdvancedIndexing(TestCase):
         out = x[idx]  # index
         self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)

-class TestRNNMPS(TestCase):
+class TestRNNMPS(TestCaseMPS):
     def test_lstm_1(self, device="mps", dtype=torch.float32):

         rnn = nn.LSTM(1, 4, 2, device="cpu")
@@ -8851,7 +9003,7 @@ for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]
     del MPS_DTYPES[MPS_DTYPES.index(t)]


-class TestConsistency(TestCase):
+class TestConsistency(TestCaseMPS):
     # TODO: This is only used while some ops are being added.
     # This list should contain all ops and dtypes eventually
     # This can be generated automatically in the `new_mps_allowlist.txt` file
@@ -1201,8 +1201,12 @@ def _multiprocessing_init() -> None: ...
 # Defined in torch/csrc/mps/Module.cpp
 def _mps_synchronize() -> None: ...
 def _mps_get_default_generator() -> Generator: ...
-def _is_mps_available() -> _bool: ...
-def _is_mps_on_macos_13_or_newer() -> _bool: ...
+def _mps_emptyCache() -> None: ...
+def _mps_setMemoryFraction(fraction: _float) -> None: ...
+def _mps_currentAllocatedMemory() -> _int: ...
+def _mps_driverAllocatedMemory() -> _int: ...
+def _mps_is_available() -> _bool: ...
+def _mps_is_on_macos_13_or_newer() -> _bool: ...

 # Defined in torch/csrc/cuda/Module.cpp
 def _cuda_getCurrentStream(device: _int) -> Tuple: ...
@@ -15,13 +15,13 @@ def is_built() -> bool:
 @_lru_cache()
 def is_available() -> bool:
     r"""Returns a bool indicating if MPS is currently available."""
-    return torch._C._is_mps_available()
+    return torch._C._mps_is_available()


 @_lru_cache()
 def is_macos13_or_newer() -> bool:
     r"""Returns a bool indicating whether MPS is running on MacOS 13 or newer."""
-    return torch._C._is_mps_on_macos_13_or_newer()
+    return torch._C._mps_is_on_macos_13_or_newer()


 # Register prims as implementation of var_mean and group_norm
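A common call-site pattern for these public wrappers (illustrative only; the renamed private _mps_* bindings are not meant to be called directly in user code):

import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
if device.type == "mps" and not torch.backends.mps.is_macos13_or_newer():
    print("MPS is available, but some features require macOS 13+")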
@@ -1,6 +1,7 @@
 #include <ATen/ATen.h>
 #include <c10/util/CallOnce.h>
 #include <torch/csrc/Generator.h>
 #include <torch/csrc/THP.h>
 #include <torch/csrc/python_headers.h>
+#include <torch/csrc/utils/python_numbers.h>
@@ -77,14 +78,51 @@ static PyObject* MPSModule_synchronize(PyObject* _unused, PyObject* noargs) {
   END_HANDLE_TH_ERRORS
 }

+static PyObject* MPSModule_emptyCache(PyObject* _unused, PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  at::detail::getMPSHooks().emptyCache();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_setMemoryFraction(
+    PyObject* _unused,
+    PyObject* args) {
+  HANDLE_TH_ERRORS
+  THPUtils_assert(
+      THPUtils_checkDouble(args), "invalid argument to setMemoryFraction()");
+  double fraction = THPUtils_unpackDouble(args);
+  at::detail::getMPSHooks().setMemoryFraction(fraction);
+  END_HANDLE_TH_ERRORS
+  Py_RETURN_NONE;
+}
+
+static PyObject* MPSModule_currentAllocatedMemory(
+    PyObject* _unused,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  return PyLong_FromUnsignedLongLong(
+      at::detail::getMPSHooks().getCurrentAllocatedMemory());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_driverAllocatedMemory(
+    PyObject* _unused,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  return PyLong_FromUnsignedLongLong(
+      at::detail::getMPSHooks().getDriverAllocatedMemory());
+  END_HANDLE_TH_ERRORS
+}
+
 // NOLINTNEXTLINE(modernize-avoid-c-arrays,
 // cppcoreguidelines-avoid-non-const-global-variables,
 // cppcoreguidelines-avoid-c-arrays)
 static struct PyMethodDef _MPSModule_methods[] = {
     {"_mps_synchronize", MPSModule_synchronize, METH_NOARGS, nullptr},
     {"_mps_is_in_bad_fork", MPSModule_isInBadFork, METH_NOARGS, nullptr},
-    {"_is_mps_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
-    {"_is_mps_on_macos_13_or_newer",
+    {"_mps_is_available", MPSModule_isAvailable, METH_NOARGS, nullptr},
+    {"_mps_is_on_macos_13_or_newer",
      MPSModule_isMacOS13orNewer,
      METH_NOARGS,
      nullptr},
@@ -92,6 +130,16 @@ static struct PyMethodDef _MPSModule_methods[] = {
      MPSModule_getDefaultMPSGenerator,
      METH_NOARGS,
      nullptr},
+    {"_mps_emptyCache", MPSModule_emptyCache, METH_NOARGS, nullptr},
+    {"_mps_setMemoryFraction", MPSModule_setMemoryFraction, METH_O, nullptr},
+    {"_mps_currentAllocatedMemory",
+     MPSModule_currentAllocatedMemory,
+     METH_NOARGS,
+     nullptr},
+    {"_mps_driverAllocatedMemory",
+     MPSModule_driverAllocatedMemory,
+     METH_NOARGS,
+     nullptr},
     {nullptr}};

 PyMethodDef* python_functions() {
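Note the registration flags: `_mps_setMemoryFraction` uses METH_O, so CPython passes the binding exactly one positional argument (the object unpacked with THPUtils_unpackDouble above), while the other new bindings use METH_NOARGS. A quick smoke test of the raw bindings (private API, subject to change):

import torch

torch._C._mps_setMemoryFraction(1.0)           # METH_O: exactly one argument
print(torch._C._mps_currentAllocatedMemory())  # METH_NOARGS
print(torch._C._mps_driverAllocatedMemory())
torch._C._mps_emptyCache()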
@@ -50,5 +50,55 @@ def seed() -> None:
     r"""Sets the seed for generating random numbers to a random number."""
     _get_default_mps_generator().seed()

+def empty_cache() -> None:
+    r"""Releases all unoccupied cached memory currently held by the caching
+    allocator so that those can be used in other GPU applications.
+    """
+    torch._C._mps_emptyCache()
+
+def set_per_process_memory_fraction(fraction) -> None:
+    r"""Set memory fraction for limiting process's memory allocation on MPS device.
+    The allowed value equals the fraction multiplied by recommended maximum device memory
+    (obtained from Metal API device.recommendedMaxWorkingSetSize).
+    If trying to allocate more than the allowed value in a process, it will raise an out of
+    memory error in allocator.
+
+    Args:
+        fraction(float): Range: 0~2. Allowed memory equals total_memory * fraction.
+
+    .. note::
+       Passing 0 to fraction means unlimited allocations
+       (may cause system failure if out of memory).
+       Passing fraction greater than 1.0 allows limits beyond the value
+       returned from device.recommendedMaxWorkingSetSize.
+    """
+
+    if not isinstance(fraction, float):
+        raise TypeError('Invalid type for fraction argument, must be `float`')
+    if fraction < 0 or fraction > 2:
+        raise ValueError('Invalid fraction value: {}. Allowed range: 0~2'.format(fraction))
+
+    torch._C._mps_setMemoryFraction(fraction)
+
+def current_allocated_memory() -> int:
+    r"""Returns the current GPU memory occupied by tensors in bytes.
+
+    .. note::
+       The returned size does not include cached allocations in
+       memory pools of MPSAllocator.
+    """
+    return torch._C._mps_currentAllocatedMemory()
+
+def driver_allocated_memory() -> int:
+    r"""Returns total GPU memory allocated by Metal driver for the process in bytes.
+
+    .. note::
+       The returned size includes cached allocations in MPSAllocator pools
+       as well as allocations from MPS/MPSGraph frameworks.
+    """
+    return torch._C._mps_driverAllocatedMemory()
+
 __all__ = [
-    'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize']
+    'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize',
+    'empty_cache', 'set_per_process_memory_fraction', 'current_allocated_memory',
+    'driver_allocated_memory']
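To make the documented 0~2 range concrete, a sketch of limiting and then lifting the cap (whether the large allocation actually trips the high watermark depends on the machine's recommendedMaxWorkingSetSize):

import torch

torch.mps.set_per_process_memory_fraction(0.5)  # cap at half the recommended working-set size
try:
    big = torch.empty(1024 * 1024 * 1024, device="mps")  # 4 GB of float32; may exceed the cap
except RuntimeError as err:
    print("allocator raised:", err)
torch.mps.set_per_process_memory_fraction(0.0)  # 0 means unlimited allocations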