Stop ignoring mypy errors in torch/testing/_internal/common_utils.py (#144483)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144483
Approved by: https://github.com/Skylion007

Parent: ea3395e4f2
Commit: 8ad37ed710

3 changed files with 61 additions and 59 deletions
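This commit narrows the type-checking suppressions instead of widening them: the file-level "# mypy: ignore-errors" directive is dropped in favor of "# mypy: allow-untyped-defs" plus targeted "# type: ignore[<error-code>]" comments on the individual offending lines, so mypy reports real errors in the file again. Most hunks are in torch/testing/_internal/common_utils.py (named in the title), with small related fixes in a GraphInfoProvider plot helper and the MPS capture wrappers. A rough sketch of the before/after suppression style (illustrative code with a hypothetical mark helper, not the actual file contents):

Before (everything in the file is silenced):

    # mypy: ignore-errors

    def mark(fn: type) -> type:
        fn.some_flag = True          # mistake or not, nothing is reported
        return fn

After (only unannotated defs are tolerated; other errors are fixed or suppressed per line, by error code):

    # mypy: allow-untyped-defs

    def mark(fn: type) -> type:
        fn.some_flag = True          # type: ignore[attr-defined]
        return fn
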
@@ -276,7 +276,7 @@ class GraphInfoProvider:
             vmin=min(self.get_knapsack_memory_input()),
             vmax=max(self.get_knapsack_memory_input()),
         )
-        cmap = cm.viridis
+        cmap = cm.viridis  # type: ignore[attr-defined]

         # Assign colors based on memory
         node_colors = [
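Throughout the diff the added suppressions name a specific error code (here attr-defined, apparently because mypy cannot see the viridis attribute on the colormap module object) rather than using a bare "# type: ignore". A minimal sketch of why the narrow form is preferable, using made-up variables:

    from typing import List

    nums: List[int] = []

    # A bare ignore hides every error on this line, including new ones later.
    nums.append("one")  # type: ignore

    # Naming the error code keeps the suppression narrow, and with mypy's
    # --warn-unused-ignores the comment is reported once it becomes stale.
    nums.append("two")  # type: ignore[arg-type]

    print(nums)  # -> ['one', 'two'] at runtime; lists are untyped containers
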
@@ -72,21 +72,21 @@ def is_metal_capture_enabled() -> bool:
     """Checks if `metal_capture` context manager is usable
     To enable metal capture, set MTL_CAPTURE_ENABLED envvar
     """
-    return torch._C._mps_isCaptureEnabled()
+    return torch._C._mps_isCaptureEnabled()  # type: ignore[attr-defined]


 def is_capturing_metal() -> bool:
     """Cheks if metal capture is in progress"""
-    return torch._C._mps_isCapturing()
+    return torch._C._mps_isCapturing()  # type: ignore[attr-defined]


 @contextlib.contextmanager
 def metal_capture(fname: str):
     """Conext manager that enables capturing of Metal calls into gputrace"""
     try:
-        torch._C._mps_startCapture(fname)
+        torch._C._mps_startCapture(fname)  # type: ignore[attr-defined]
         yield
         # Drain all the work that were enqueued during the context call
         torch.mps.synchronize()
     finally:
-        torch._C._mps_stopCapture()
+        torch._C._mps_stopCapture()  # type: ignore[attr-defined]
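The MPS helpers call private bindings on the compiled torch._C extension module, whose attributes are not visible to mypy, so each call gets an attr-defined suppression. A self-contained sketch of the same situation using a hypothetical stand-in module (not the real torch._C):

    import types

    # Stand-in for a compiled extension module such as torch._C: its attributes
    # exist only at runtime, so a static checker cannot see them.
    _c = types.ModuleType("_c")
    _c._mps_isCaptureEnabled = lambda: False  # type: ignore[attr-defined]

    def is_capture_enabled() -> bool:
        # The call works at runtime, but mypy flags the attribute access as
        # attr-defined, hence the targeted suppression mirroring the PR.
        return _c._mps_isCaptureEnabled()  # type: ignore[attr-defined]

    print(is_capture_enabled())  # -> False
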
@@ -1,4 +1,4 @@
-# mypy: ignore-errors
+# mypy: allow-untyped-defs

 r"""Importing this file must **not** initialize CUDA context. test_distributed
 relies on this assumption to properly run. This means that when this is imported
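The key change for common_utils.py is this header hunk: the blanket ignore-errors directive becomes allow-untyped-defs. A small illustrative module (not the real file) showing what the new directive does and does not relax:

    # mypy: allow-untyped-defs
    # Illustrative module: this directive only exempts the file from
    # "function must be annotated" errors; other diagnostics are reported
    # again, unlike under "# mypy: ignore-errors".

    def untyped_helper(x):             # fine: missing annotations are allowed
        return x

    def typed_helper(n: int) -> str:
        return str(n + 1)              # annotated code is fully checked as usual

    print(untyped_helper("ok"), typed_helper(41))
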
@@ -156,6 +156,7 @@ class TestEnvironment:
         implied_by_fn=lambda: False,
     ):
         enabled = default
+        env_var_val = None
         if env_var is not None:
             env_var_val = os.getenv(env_var)
         enabled = enabled_fn(env_var_val, default)
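Here the fix is a real code change rather than a suppression: env_var_val is initialized up front so it is defined on every path, which is presumably what mypy's used-before-definition / possibly-undefined checking objected to. A minimal sketch of the pattern with a hypothetical helper:

    import os
    from typing import Optional

    def read_flag(env_var: Optional[str], default: bool) -> bool:
        env_var_val = None            # initialize up front so every path defines it
        if env_var is not None:
            env_var_val = os.getenv(env_var)
        return default if env_var_val is None else env_var_val == "1"

    print(read_flag(None, True))      # -> True
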
@@ -351,17 +352,15 @@ def get_tracked_input() -> Optional[TrackedInput]:
     test_fn = extract_test_fn()
     if test_fn is None:
         return None
-    if not hasattr(test_fn, "tracked_input"):
-        return None
-    return test_fn.tracked_input
+    return getattr(test_fn, "tracked_input", None)

-def clear_tracked_input():
+def clear_tracked_input() -> None:
     test_fn = extract_test_fn()
     if test_fn is None:
         return
     if not hasattr(test_fn, "tracked_input"):
-        return None
-    test_fn.tracked_input = None
+        return
+    test_fn.tracked_input = None  # type: ignore[attr-defined]

 # Wraps an iterator and tracks the most recent value the iterator produces
 # for debugging purposes. Tracked values are stored on the test function.
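get_tracked_input is simplified so that the optional attribute is read with getattr and a default instead of a hasattr check followed by a dynamic access that mypy cannot type. A small sketch with a hypothetical holder class:

    from typing import Any, Optional

    class _Holder:
        pass

    def read_tracked(obj: _Holder) -> Optional[Any]:
        # Equivalent to: obj.tracked_input if hasattr(obj, "tracked_input") else None,
        # but without an attribute access that mypy would flag as attr-defined.
        return getattr(obj, "tracked_input", None)

    print(read_tracked(_Holder()))  # -> None
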
@@ -428,7 +427,7 @@ class TrackedInputIter:
             return
         if not hasattr(self.test_fn, "tracked_input"):
             return
-        self.test_fn.tracked_input = tracked_input
+        self.test_fn.tracked_input = tracked_input  # type: ignore[attr-defined]

 class _TestParametrizer:
     """
@@ -690,7 +689,7 @@ class parametrize(_TestParametrizer):
         for idx, values in enumerate(self.arg_values):
             maybe_name = None

-            decorators = []
+            decorators: List[Any] = []
             if isinstance(values, subtest):
                 sub = values
                 values = sub.arg_values
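This hunk is another genuine fix: an empty list literal gives mypy nothing to infer an element type from, so decorators gets an explicit List[Any] annotation instead of an ignore comment. Sketch:

    from typing import Any, List

    decorators: List[Any] = []      # element type stated up front
    decorators.append("marker")
    print(decorators)               # -> ['marker']
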
@@ -705,7 +704,7 @@ class parametrize(_TestParametrizer):
             else:
                 gen_test = test

-            values = list(values) if len(self.arg_names) > 1 else [values]
+            values = list(values) if len(self.arg_names) > 1 else [values]  # type: ignore[call-overload]
             if len(values) != len(self.arg_names):
                 raise RuntimeError(f'Expected # values == # arg names, but got: {len(values)} '
                                    f'values and {len(self.arg_names)} names for test "{test.__name__}"')
@@ -863,6 +862,8 @@ def cppProfilingFlagsToProfilingMode():

 @contextmanager
 def enable_profiling_mode_for_profiling_tests():
+    old_prof_exec_state = False
+    old_prof_mode_state = False
     if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
         old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
         old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@@ -1175,7 +1176,7 @@ def sanitize_pytest_xml(xml_file: str):
 def get_pytest_test_cases(argv: List[str]) -> List[str]:
     class TestCollectorPlugin:
         def __init__(self) -> None:
-            self.tests = []
+            self.tests: List[Any] = []

         def pytest_collection_finish(self, session):
             for item in session.items:
@@ -1286,6 +1287,7 @@ def run_tests(argv=UNITTEST_ARGS):
         assert not failed, "Some test shards have failed"
     elif USE_PYTEST:
         pytest_args = argv + ["--use-main-module"]
+        test_report_path = ""
         if TEST_SAVE_XML:
             test_report_path = get_report_path(pytest=True)
             print(f'Test results will be stored in {test_report_path}')
@@ -1411,7 +1413,7 @@ else:
     def is_privateuse1_backend_available():
         privateuse1_backend_name = torch._C._get_privateuse1_backend_name()
         privateuse1_backend_module = getattr(torch, privateuse1_backend_name, None)
-        return hasattr(privateuse1_backend_module, "is_available") and privateuse1_backend_module.is_available()
+        return (is_available := getattr(privateuse1_backend_module, "is_available", None)) and is_available()


 IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8'
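The rewritten return uses an assignment expression so the attribute is fetched once, tested, and only then called; mypy can follow that narrowing, unlike the old hasattr-then-call form. Sketch with a hypothetical backend object:

    class _Backend:
        @staticmethod
        def is_available() -> bool:
            return True

    backend = _Backend()
    # Bind getattr's result and call it only when it is truthy.
    result = (is_available := getattr(backend, "is_available", None)) and is_available()
    print(result)  # -> True; would be None if the attribute were missing
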
@@ -1628,8 +1630,8 @@ def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):

         assert isinstance(fn, type)
         if TEST_WITH_TORCHDYNAMO:
-            fn.__unittest_skip__ = True
-            fn.__unittest_skip_why__ = msg
+            fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+            fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]

         return fn

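unittest's skip machinery stores __unittest_skip__ and __unittest_skip_why__ as ad-hoc attributes on the test class, and since the class is typed as a plain type after the isinstance assertion, mypy reports attr-defined on each assignment; this and the following three skipIf* hunks all add the same pair of targeted ignores. A hypothetical, self-contained sketch of the pattern:

    def mark_skipped(cls: type, why: str) -> type:
        # `type` declares no such attributes, so mypy flags both assignments.
        cls.__unittest_skip__ = True        # type: ignore[attr-defined]
        cls.__unittest_skip_why__ = why     # type: ignore[attr-defined]
        return cls

    class _Dummy:
        pass

    mark_skipped(_Dummy, "example only")
    print(_Dummy.__unittest_skip__)         # type: ignore[attr-defined]
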
@@ -1649,8 +1651,8 @@ def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",

         assert isinstance(fn, type)
         if condition:
-            fn.__unittest_skip__ = True
-            fn.__unittest_skip_why__ = msg
+            fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+            fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]

         return fn

@@ -1723,8 +1725,8 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe

         assert isinstance(fn, type)
         if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
-            fn.__unittest_skip__ = True
-            fn.__unittest_skip_why__ = msg
+            fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+            fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]

         return fn

@@ -1832,8 +1834,8 @@ def skipIfNNModuleInlined(

         assert isinstance(fn, type)
         if condition:
-            fn.__unittest_skip__ = True
-            fn.__unittest_skip_why__ = msg
+            fn.__unittest_skip__ = True  # type: ignore[attr-defined]
+            fn.__unittest_skip_why__ = msg  # type: ignore[attr-defined]

         return fn

@@ -2013,17 +2015,17 @@ class DeterministicGuard:
     def __enter__(self):
         self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
         self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
-        self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory
+        self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory  # type: ignore[attr-defined]
         torch.use_deterministic_algorithms(
             self.deterministic,
             warn_only=self.warn_only)
-        torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory
+        torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory  # type: ignore[attr-defined]

     def __exit__(self, exception_type, exception_value, traceback):
         torch.use_deterministic_algorithms(
             self.deterministic_restore,
             warn_only=self.warn_only_restore)
-        torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore
+        torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore  # type: ignore[attr-defined]

 class AlwaysWarnTypedStorageRemoval:
     def __init__(self, always_warn):
@@ -2176,7 +2178,7 @@ def _test_function(fn, device):
 def skipIfNoXNNPACK(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
-        if not torch.backends.xnnpack.enabled:
+        if not torch.backends.xnnpack.enabled:  # type: ignore[attr-defined]
            raise unittest.SkipTest('XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.')
        else:
            fn(*args, **kwargs)
@@ -2280,7 +2282,7 @@ def to_gpu(obj, type_map=None):
         res.requires_grad = obj.requires_grad
         return res
     elif torch.is_storage(obj):
-        return obj.new().resize_(obj.size()).copy_(obj)
+        return obj.new().resize_(obj.size()).copy_(obj)  # type: ignore[attr-defined, union-attr]
     elif isinstance(obj, list):
         return [to_gpu(o, type_map) for o in obj]
     elif isinstance(obj, tuple):
@@ -2482,13 +2484,13 @@ class CudaMemoryLeakCheck:
             if not discrepancy_detected:
                 continue

-            if caching_allocator_discrepancy and not driver_discrepancy:
+            if caching_allocator_discrepancy and not driver_discrepancy:  # type: ignore[possibly-undefined]
                 # Just raises a warning if the leak is not validated by the
                 # driver API
                 # NOTE: this may be a problem with how the caching allocator collects its
                 # statistics or a leak too small to trigger the allocation of an
                 # additional block of memory by the CUDA driver
-                msg = ("CUDA caching allocator reports a memory leak not "
+                msg = ("CUDA caching allocator reports a memory leak not "  # type: ignore[possibly-undefined]
                        f"verified by the driver API in {self.name}! "
                        f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
                        f"and is now reported as {caching_allocator_mem_allocated} "
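possibly-undefined is an opt-in mypy error code that fires when a name is assigned on only some control-flow paths; in this checking loop the variables are in fact set before use, but mypy cannot prove it, so the hunk suppresses the code rather than restructuring the harness. A much-simplified sketch (hypothetical flow, not the real leak check):

    import random

    detected = random.random() < 2.0     # always True, but not knowable statically
    if detected:
        details = "leak details"
    if detected:
        # mypy (with possibly-undefined enabled) flags `details` here even
        # though this path is only taken when it was assigned above.
        print(details)
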
@@ -2498,7 +2500,7 @@ class CudaMemoryLeakCheck:
             elif caching_allocator_discrepancy and driver_discrepancy:
                 # A caching allocator discrepancy validated by the driver API is a
                 # failure (except on ROCm, see below)
-                msg = (f"CUDA driver API confirmed a leak in {self.name}! "
+                msg = (f"CUDA driver API confirmed a leak in {self.name}! "  # type: ignore[possibly-undefined]
                        f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
                        f"and is now reported as {caching_allocator_mem_allocated} "
                        f"on device {i}. "
@@ -3022,7 +3024,7 @@ class TestCase(expecttest.TestCase):
                 lambda repro_parts=repro_parts: print_repro_on_failure(repro_parts))
         except Exception as e:
             # Don't fail entirely if we can't get the test filename
-            log.info("could not print repro string", extra=str(e))
+            log.info("could not print repro string", extra=str(e))  # type: ignore[arg-type]

     def assertLeaksNoCudaTensors(self, name=None):
         name = self.id() if name is None else name
@@ -3133,7 +3135,7 @@ class TestCase(expecttest.TestCase):
         using_unittest = isinstance(result, unittest.TestResult)

         super_run = super().run
-        test_cls = super_run.__self__
+        test_cls = super_run.__self__  # type: ignore[attr-defined]

         # Are we compiling?
         compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR
@@ -3244,9 +3246,9 @@ class TestCase(expecttest.TestCase):
                 # Create dummy TestInfo to record results correctly
                 from xmlrunner.result import _TestInfo  # type: ignore[import]
                 case = _TestInfo(result, case)
-                case.output = _TestInfo.ERROR
-                case.elapsed_time = 0.0
-                case.test_description = "TestSuiteEarlyFailure"
+                case.output = _TestInfo.ERROR  # type: ignore[attr-defined]
+                case.elapsed_time = 0.0  # type: ignore[attr-defined]
+                case.test_description = "TestSuiteEarlyFailure"  # type: ignore[attr-defined]
                 # This shouldn't really happen, but if does add fake failure
                 # For more details see https://github.com/pytorch/pytorch/issues/71973
                 result.failures.append((case, "TestSuite execution was aborted early"))
@@ -3680,7 +3682,7 @@ class TestCase(expecttest.TestCase):
                 return get_sparse_data_with_block(pattern, blocksize)

             # batch data is created recursively:
-            batch_data = {}
+            batch_data = {}  # type: ignore[var-annotated]
             for i, item in enumerate(pattern):
                 for layout, d in get_batch_sparse_data(item, blocksize).items():
                     target = batch_data.get(layout)
@@ -3842,30 +3844,30 @@ class TestCase(expecttest.TestCase):
             for blocksize in blocksizes:
                 for densesize in densesizes:
                     if layout == torch.strided:
-                        indices = ()
+                        indices = ()  # type: ignore[assignment]
                         values = torch.empty((basesize + densesize), device=device, dtype=dtype)
                     elif layout == torch.sparse_coo:
-                        indices = (torch.empty(len(basesize), 0, device=device, dtype=index_dtype),)
+                        indices = (torch.empty(len(basesize), 0, device=device, dtype=index_dtype),)  # type: ignore[assignment]
                         values = torch.empty((0, *densesize), device=device, dtype=dtype)
                     elif layout == torch.sparse_csr:
                         crow_indices = torch.tensor([0] * (basesize[0] + 1), device=device, dtype=index_dtype)
                         col_indices = torch.empty(0, device=device, dtype=index_dtype)
-                        indices = (crow_indices, col_indices)
+                        indices = (crow_indices, col_indices)  # type: ignore[assignment]
                         values = torch.empty((0, *densesize), device=device, dtype=dtype)
                     elif layout == torch.sparse_csc:
                         ccol_indices = torch.tensor([0] * (basesize[1] + 1), device=device, dtype=index_dtype)
                         row_indices = torch.empty(0, device=device, dtype=index_dtype)
-                        indices = (ccol_indices, row_indices)
+                        indices = (ccol_indices, row_indices)  # type: ignore[assignment]
                         values = torch.empty((0, *densesize), device=device, dtype=dtype)
                     elif layout == torch.sparse_bsr:
                         crow_indices = torch.tensor([0] * (basesize[0] // blocksize[0] + 1), device=device, dtype=index_dtype)
                         col_indices = torch.empty(0, device=device, dtype=index_dtype)
-                        indices = (crow_indices, col_indices)
+                        indices = (crow_indices, col_indices)  # type: ignore[assignment]
                         values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
                     elif layout == torch.sparse_bsc:
                         ccol_indices = torch.tensor([0] * (basesize[1] // blocksize[1] + 1), device=device, dtype=index_dtype)
                         row_indices = torch.empty(0, device=device, dtype=index_dtype)
-                        indices = (ccol_indices, row_indices)
+                        indices = (ccol_indices, row_indices)  # type: ignore[assignment]
                         values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
                     else:
                         assert 0  # unreachable
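mypy fixes a variable's type from its first assignment, so rebinding indices to differently shaped tuples in the later branches is reported as an assignment error; the hunk silences each rebinding, while an explicit union annotation would be the non-suppression alternative. Sketch:

    from typing import Tuple, Union

    layout = "strided"
    # Declaring the union up front lets every branch assign without an ignore.
    indices: Union[Tuple[()], Tuple[int, ...]]
    if layout == "strided":
        indices = ()
    else:
        indices = (0, 1)
    print(indices)  # -> ()
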
@@ -4029,9 +4031,9 @@ class TestCase(expecttest.TestCase):

         if error_metas:
             # See [ErrorMeta Cycles]
-            error_metas = [error_metas]
+            error_metas = [error_metas]  # type: ignore[list-item]
             # TODO: compose all metas into one AssertionError
-            raise error_metas.pop()[0].to_error(
+            raise error_metas.pop()[0].to_error(  # type: ignore[index]
                 # This emulates unittest.TestCase's behavior if a custom message passed and
                 # TestCase.longMessage (https://docs.python.org/3/library/unittest.html#unittest.TestCase.longMessage)
                 # is True (default)
@@ -4062,7 +4064,7 @@ class TestCase(expecttest.TestCase):
            context: Optional[AssertRaisesContextIgnoreNotImplementedError] = \
                AssertRaisesContextIgnoreNotImplementedError(expected_exception, self)  # type: ignore[call-arg]
            try:
-               return context.handle('assertRaises', args, kwargs)  # type: ignore[union-attr]
+               return context.handle('assertRaises', args, kwargs)  # type: ignore[union-attr, arg-type]
            finally:
                # see https://bugs.python.org/issue23890
                context = None
@@ -4087,7 +4089,7 @@ class TestCase(expecttest.TestCase):
         if self._ignore_not_implemented_error:
             context = AssertRaisesContextIgnoreNotImplementedError(  # type: ignore[call-arg]
                 expected_exception, self, expected_regex)
-            return context.handle('assertRaisesRegex', args, kwargs)  # type: ignore[attr-defined]
+            return context.handle('assertRaisesRegex', args, kwargs)  # type: ignore[attr-defined, arg-type]
         else:
             return super().assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)

@@ -4185,8 +4187,8 @@ class TestCase(expecttest.TestCase):
         # test/common_utils.py, but it matters in onnx-pytorch
         module_id = self.__class__.__module__
         munged_id = remove_prefix(self.id(), module_id + ".")
-        test_file = os.path.realpath(sys.modules[module_id].__file__)
-        expected_file = os.path.join(os.path.dirname(test_file),
+        test_file = os.path.realpath(sys.modules[module_id].__file__)  # type: ignore[type-var]
+        expected_file = os.path.join(os.path.dirname(test_file),  # type: ignore[type-var, arg-type]
                                      "expect",
                                      munged_id)

@@ -4284,7 +4286,7 @@ class TestCase(expecttest.TestCase):
             if attrs.get("operator") == operator:
                 break

-        self.assertEqual(attrs["operator"], operator)
+        self.assertEqual(attrs["operator"], operator)  # type: ignore[possibly-undefined]
         self.assertEqual(attrs.get("overload_name", ""), overload_name)

     def check_nondeterministic_alert(self, fn, caller_name, should_alert=True):
@@ -4702,7 +4704,7 @@ def random_lowrank_matrix(rank, rows, columns, *batch_dims, **kwargs):

 def _generate_indices_prefer_all_rows(rows: int, cols: int, num_indices: int) -> torch.Tensor:
     """Generate indices for a row x cols matrix, preferring at least one index per row if possible."""
-    indices = []
+    indices = []  # type: ignore[var-annotated]
     n_per_row = math.ceil(num_indices / rows)
     col_indices = list(range(cols))

@@ -4854,7 +4856,7 @@ def do_test_empty_full(self, dtypes, layout, device):
                     int64_dtype, layout, device, fv + 5, False)

 # FIXME: improve load_tests() documentation here
-running_script_path = None
+running_script_path = None  # type: ignore[var-annotated]
 def set_running_script_path():
     global running_script_path
     try:
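A module-level name initialized to a bare None gives mypy only a None/partial type to work with, so the later string assignment trips its var-annotated check; the hunk keeps the ignore, whereas an Optional[str] annotation would be the annotation-based fix. Sketch with a hypothetical global:

    from typing import Optional

    running_path: Optional[str] = None   # annotated, so later str assignments type-check

    def set_running_path(path: str) -> None:
        global running_path
        running_path = path

    set_running_path("/tmp/example_test.py")
    print(running_path)
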
@@ -5116,7 +5118,7 @@ def copy_func(f):
                            argdefs=f.__defaults__,
                            closure=f.__closure__)
    g = functools.update_wrapper(g, f)
-   g.__kwdefaults__ = f.__kwdefaults__
+   g.__kwdefaults__ = f.__kwdefaults__  # type: ignore[attr-defined]
    return g

@@ -5304,8 +5306,8 @@ class TestGradients(TestCase):
         if is_iterable_of_tensors(sample.input):
             all_args = chain(sample.input, sample.args, sample.kwargs.values())
         else:
-            all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))
-        gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))
+            all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))  # type: ignore[assignment]
+        gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))  # type: ignore[union-attr]

         # Verifies sample input tensors should have no grad
         # This may happen if the same tensor is used in two different SampleInputs
@@ -5525,7 +5527,7 @@ def check_leaked_tensors(limit=1, matched_type=torch.Tensor):
     try:
         gc.collect()
         gc.set_debug(gc.DEBUG_SAVEALL)
-        garbage_objs = []
+        garbage_objs = []  # type: ignore[var-annotated]

        # run the user code, after cleaning any existing refcycles, and then check for new ones
        # also allow usercode to check the garbage objs (e.g. for assertion) after exiting ctxmgr
@@ -5539,7 +5541,7 @@ def check_leaked_tensors(limit=1, matched_type=torch.Tensor):
                 f"{num_garbage_objs} tensors were found in the garbage. Did you introduce a reference cycle?"
             )
             try:
-                import objgraph
+                import objgraph  # type: ignore[import-not-found]
                 warnings.warn(
                     f"Dumping first {limit} objgraphs of leaked {matched_type}s rendered to png"
                 )