Stop ignoring mypy errors in torch/testing/_internal/common_utils.py (#144483)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144483
Approved by: https://github.com/Skylion007
This commit is contained in:
Aaron Orenstein 2025-01-13 17:23:07 -08:00 committed by PyTorch MergeBot
parent ea3395e4f2
commit 8ad37ed710
3 changed files with 61 additions and 59 deletions

View file

@ -276,7 +276,7 @@ class GraphInfoProvider:
vmin=min(self.get_knapsack_memory_input()),
vmax=max(self.get_knapsack_memory_input()),
)
cmap = cm.viridis
cmap = cm.viridis # type: ignore[attr-defined]
# Assign colors based on memory
node_colors = [

View file

@ -72,21 +72,21 @@ def is_metal_capture_enabled() -> bool:
"""Checks if `metal_capture` context manager is usable
To enable metal capture, set MTL_CAPTURE_ENABLED envvar
"""
return torch._C._mps_isCaptureEnabled()
return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined]
def is_capturing_metal() -> bool:
"""Cheks if metal capture is in progress"""
return torch._C._mps_isCapturing()
return torch._C._mps_isCapturing() # type: ignore[attr-defined]
@contextlib.contextmanager
def metal_capture(fname: str):
"""Conext manager that enables capturing of Metal calls into gputrace"""
try:
torch._C._mps_startCapture(fname)
torch._C._mps_startCapture(fname) # type: ignore[attr-defined]
yield
# Drain all the work that were enqueued during the context call
torch.mps.synchronize()
finally:
torch._C._mps_stopCapture()
torch._C._mps_stopCapture() # type: ignore[attr-defined]

View file

@ -1,4 +1,4 @@
# mypy: ignore-errors
# mypy: allow-untyped-defs
r"""Importing this file must **not** initialize CUDA context. test_distributed
relies on this assumption to properly run. This means that when this is imported
@ -156,6 +156,7 @@ class TestEnvironment:
implied_by_fn=lambda: False,
):
enabled = default
env_var_val = None
if env_var is not None:
env_var_val = os.getenv(env_var)
enabled = enabled_fn(env_var_val, default)
@ -351,17 +352,15 @@ def get_tracked_input() -> Optional[TrackedInput]:
test_fn = extract_test_fn()
if test_fn is None:
return None
if not hasattr(test_fn, "tracked_input"):
return None
return test_fn.tracked_input
return getattr(test_fn, "tracked_input", None)
def clear_tracked_input():
def clear_tracked_input() -> None:
test_fn = extract_test_fn()
if test_fn is None:
return
if not hasattr(test_fn, "tracked_input"):
return None
test_fn.tracked_input = None
return
test_fn.tracked_input = None # type: ignore[attr-defined]
# Wraps an iterator and tracks the most recent value the iterator produces
# for debugging purposes. Tracked values are stored on the test function.
@ -428,7 +427,7 @@ class TrackedInputIter:
return
if not hasattr(self.test_fn, "tracked_input"):
return
self.test_fn.tracked_input = tracked_input
self.test_fn.tracked_input = tracked_input # type: ignore[attr-defined]
class _TestParametrizer:
"""
@ -690,7 +689,7 @@ class parametrize(_TestParametrizer):
for idx, values in enumerate(self.arg_values):
maybe_name = None
decorators = []
decorators: List[Any] = []
if isinstance(values, subtest):
sub = values
values = sub.arg_values
@ -705,7 +704,7 @@ class parametrize(_TestParametrizer):
else:
gen_test = test
values = list(values) if len(self.arg_names) > 1 else [values]
values = list(values) if len(self.arg_names) > 1 else [values] # type: ignore[call-overload]
if len(values) != len(self.arg_names):
raise RuntimeError(f'Expected # values == # arg names, but got: {len(values)} '
f'values and {len(self.arg_names)} names for test "{test.__name__}"')
@ -863,6 +862,8 @@ def cppProfilingFlagsToProfilingMode():
@contextmanager
def enable_profiling_mode_for_profiling_tests():
old_prof_exec_state = False
old_prof_mode_state = False
if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
old_prof_mode_state = torch._C._get_graph_executor_optimize(True)
@ -1175,7 +1176,7 @@ def sanitize_pytest_xml(xml_file: str):
def get_pytest_test_cases(argv: List[str]) -> List[str]:
class TestCollectorPlugin:
def __init__(self) -> None:
self.tests = []
self.tests: List[Any] = []
def pytest_collection_finish(self, session):
for item in session.items:
@ -1286,6 +1287,7 @@ def run_tests(argv=UNITTEST_ARGS):
assert not failed, "Some test shards have failed"
elif USE_PYTEST:
pytest_args = argv + ["--use-main-module"]
test_report_path = ""
if TEST_SAVE_XML:
test_report_path = get_report_path(pytest=True)
print(f'Test results will be stored in {test_report_path}')
@ -1411,7 +1413,7 @@ else:
def is_privateuse1_backend_available():
privateuse1_backend_name = torch._C._get_privateuse1_backend_name()
privateuse1_backend_module = getattr(torch, privateuse1_backend_name, None)
return hasattr(privateuse1_backend_module, "is_available") and privateuse1_backend_module.is_available()
return (is_available := getattr(privateuse1_backend_module, "is_available", None)) and is_available()
IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8'
@ -1628,8 +1630,8 @@ def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"):
assert isinstance(fn, type)
if TEST_WITH_TORCHDYNAMO:
fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = msg
fn.__unittest_skip__ = True # type: ignore[attr-defined]
fn.__unittest_skip_why__ = msg # type: ignore[attr-defined]
return fn
@ -1649,8 +1651,8 @@ def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
assert isinstance(fn, type)
if condition:
fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = msg
fn.__unittest_skip__ = True # type: ignore[attr-defined]
fn.__unittest_skip_why__ = msg # type: ignore[attr-defined]
return fn
@ -1723,8 +1725,8 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe
assert isinstance(fn, type)
if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = msg
fn.__unittest_skip__ = True # type: ignore[attr-defined]
fn.__unittest_skip_why__ = msg # type: ignore[attr-defined]
return fn
@ -1832,8 +1834,8 @@ def skipIfNNModuleInlined(
assert isinstance(fn, type)
if condition:
fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = msg
fn.__unittest_skip__ = True # type: ignore[attr-defined]
fn.__unittest_skip_why__ = msg # type: ignore[attr-defined]
return fn
@ -2013,17 +2015,17 @@ class DeterministicGuard:
def __enter__(self):
self.deterministic_restore = torch.are_deterministic_algorithms_enabled()
self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled()
self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory
self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined]
torch.use_deterministic_algorithms(
self.deterministic,
warn_only=self.warn_only)
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory # type: ignore[attr-defined]
def __exit__(self, exception_type, exception_value, traceback):
torch.use_deterministic_algorithms(
self.deterministic_restore,
warn_only=self.warn_only_restore)
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore
torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore # type: ignore[attr-defined]
class AlwaysWarnTypedStorageRemoval:
def __init__(self, always_warn):
@ -2176,7 +2178,7 @@ def _test_function(fn, device):
def skipIfNoXNNPACK(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
if not torch.backends.xnnpack.enabled:
if not torch.backends.xnnpack.enabled: # type: ignore[attr-defined]
raise unittest.SkipTest('XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.')
else:
fn(*args, **kwargs)
@ -2280,7 +2282,7 @@ def to_gpu(obj, type_map=None):
res.requires_grad = obj.requires_grad
return res
elif torch.is_storage(obj):
return obj.new().resize_(obj.size()).copy_(obj)
return obj.new().resize_(obj.size()).copy_(obj) # type: ignore[attr-defined, union-attr]
elif isinstance(obj, list):
return [to_gpu(o, type_map) for o in obj]
elif isinstance(obj, tuple):
@ -2482,13 +2484,13 @@ class CudaMemoryLeakCheck:
if not discrepancy_detected:
continue
if caching_allocator_discrepancy and not driver_discrepancy:
if caching_allocator_discrepancy and not driver_discrepancy: # type: ignore[possibly-undefined]
# Just raises a warning if the leak is not validated by the
# driver API
# NOTE: this may be a problem with how the caching allocator collects its
# statistics or a leak too small to trigger the allocation of an
# additional block of memory by the CUDA driver
msg = ("CUDA caching allocator reports a memory leak not "
msg = ("CUDA caching allocator reports a memory leak not " # type: ignore[possibly-undefined]
f"verified by the driver API in {self.name}! "
f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
f"and is now reported as {caching_allocator_mem_allocated} "
@ -2498,7 +2500,7 @@ class CudaMemoryLeakCheck:
elif caching_allocator_discrepancy and driver_discrepancy:
# A caching allocator discrepancy validated by the driver API is a
# failure (except on ROCm, see below)
msg = (f"CUDA driver API confirmed a leak in {self.name}! "
msg = (f"CUDA driver API confirmed a leak in {self.name}! " # type: ignore[possibly-undefined]
f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} "
f"and is now reported as {caching_allocator_mem_allocated} "
f"on device {i}. "
@ -3022,7 +3024,7 @@ class TestCase(expecttest.TestCase):
lambda repro_parts=repro_parts: print_repro_on_failure(repro_parts))
except Exception as e:
# Don't fail entirely if we can't get the test filename
log.info("could not print repro string", extra=str(e))
log.info("could not print repro string", extra=str(e)) # type: ignore[arg-type]
def assertLeaksNoCudaTensors(self, name=None):
name = self.id() if name is None else name
@ -3133,7 +3135,7 @@ class TestCase(expecttest.TestCase):
using_unittest = isinstance(result, unittest.TestResult)
super_run = super().run
test_cls = super_run.__self__
test_cls = super_run.__self__ # type: ignore[attr-defined]
# Are we compiling?
compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR
@ -3244,9 +3246,9 @@ class TestCase(expecttest.TestCase):
# Create dummy TestInfo to record results correctly
from xmlrunner.result import _TestInfo # type: ignore[import]
case = _TestInfo(result, case)
case.output = _TestInfo.ERROR
case.elapsed_time = 0.0
case.test_description = "TestSuiteEarlyFailure"
case.output = _TestInfo.ERROR # type: ignore[attr-defined]
case.elapsed_time = 0.0 # type: ignore[attr-defined]
case.test_description = "TestSuiteEarlyFailure" # type: ignore[attr-defined]
# This shouldn't really happen, but if does add fake failure
# For more details see https://github.com/pytorch/pytorch/issues/71973
result.failures.append((case, "TestSuite execution was aborted early"))
@ -3680,7 +3682,7 @@ class TestCase(expecttest.TestCase):
return get_sparse_data_with_block(pattern, blocksize)
# batch data is created recursively:
batch_data = {}
batch_data = {} # type: ignore[var-annotated]
for i, item in enumerate(pattern):
for layout, d in get_batch_sparse_data(item, blocksize).items():
target = batch_data.get(layout)
@ -3842,30 +3844,30 @@ class TestCase(expecttest.TestCase):
for blocksize in blocksizes:
for densesize in densesizes:
if layout == torch.strided:
indices = ()
indices = () # type: ignore[assignment]
values = torch.empty((basesize + densesize), device=device, dtype=dtype)
elif layout == torch.sparse_coo:
indices = (torch.empty(len(basesize), 0, device=device, dtype=index_dtype),)
indices = (torch.empty(len(basesize), 0, device=device, dtype=index_dtype),) # type: ignore[assignment]
values = torch.empty((0, *densesize), device=device, dtype=dtype)
elif layout == torch.sparse_csr:
crow_indices = torch.tensor([0] * (basesize[0] + 1), device=device, dtype=index_dtype)
col_indices = torch.empty(0, device=device, dtype=index_dtype)
indices = (crow_indices, col_indices)
indices = (crow_indices, col_indices) # type: ignore[assignment]
values = torch.empty((0, *densesize), device=device, dtype=dtype)
elif layout == torch.sparse_csc:
ccol_indices = torch.tensor([0] * (basesize[1] + 1), device=device, dtype=index_dtype)
row_indices = torch.empty(0, device=device, dtype=index_dtype)
indices = (ccol_indices, row_indices)
indices = (ccol_indices, row_indices) # type: ignore[assignment]
values = torch.empty((0, *densesize), device=device, dtype=dtype)
elif layout == torch.sparse_bsr:
crow_indices = torch.tensor([0] * (basesize[0] // blocksize[0] + 1), device=device, dtype=index_dtype)
col_indices = torch.empty(0, device=device, dtype=index_dtype)
indices = (crow_indices, col_indices)
indices = (crow_indices, col_indices) # type: ignore[assignment]
values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
elif layout == torch.sparse_bsc:
ccol_indices = torch.tensor([0] * (basesize[1] // blocksize[1] + 1), device=device, dtype=index_dtype)
row_indices = torch.empty(0, device=device, dtype=index_dtype)
indices = (ccol_indices, row_indices)
indices = (ccol_indices, row_indices) # type: ignore[assignment]
values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype)
else:
assert 0 # unreachable
@ -4029,9 +4031,9 @@ class TestCase(expecttest.TestCase):
if error_metas:
# See [ErrorMeta Cycles]
error_metas = [error_metas]
error_metas = [error_metas] # type: ignore[list-item]
# TODO: compose all metas into one AssertionError
raise error_metas.pop()[0].to_error(
raise error_metas.pop()[0].to_error( # type: ignore[index]
# This emulates unittest.TestCase's behavior if a custom message passed and
# TestCase.longMessage (https://docs.python.org/3/library/unittest.html#unittest.TestCase.longMessage)
# is True (default)
@ -4062,7 +4064,7 @@ class TestCase(expecttest.TestCase):
context: Optional[AssertRaisesContextIgnoreNotImplementedError] = \
AssertRaisesContextIgnoreNotImplementedError(expected_exception, self) # type: ignore[call-arg]
try:
return context.handle('assertRaises', args, kwargs) # type: ignore[union-attr]
return context.handle('assertRaises', args, kwargs) # type: ignore[union-attr, arg-type]
finally:
# see https://bugs.python.org/issue23890
context = None
@ -4087,7 +4089,7 @@ class TestCase(expecttest.TestCase):
if self._ignore_not_implemented_error:
context = AssertRaisesContextIgnoreNotImplementedError( # type: ignore[call-arg]
expected_exception, self, expected_regex)
return context.handle('assertRaisesRegex', args, kwargs) # type: ignore[attr-defined]
return context.handle('assertRaisesRegex', args, kwargs) # type: ignore[attr-defined, arg-type]
else:
return super().assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
@ -4185,8 +4187,8 @@ class TestCase(expecttest.TestCase):
# test/common_utils.py, but it matters in onnx-pytorch
module_id = self.__class__.__module__
munged_id = remove_prefix(self.id(), module_id + ".")
test_file = os.path.realpath(sys.modules[module_id].__file__)
expected_file = os.path.join(os.path.dirname(test_file),
test_file = os.path.realpath(sys.modules[module_id].__file__) # type: ignore[type-var]
expected_file = os.path.join(os.path.dirname(test_file), # type: ignore[type-var, arg-type]
"expect",
munged_id)
@ -4284,7 +4286,7 @@ class TestCase(expecttest.TestCase):
if attrs.get("operator") == operator:
break
self.assertEqual(attrs["operator"], operator)
self.assertEqual(attrs["operator"], operator) # type: ignore[possibly-undefined]
self.assertEqual(attrs.get("overload_name", ""), overload_name)
def check_nondeterministic_alert(self, fn, caller_name, should_alert=True):
@ -4702,7 +4704,7 @@ def random_lowrank_matrix(rank, rows, columns, *batch_dims, **kwargs):
def _generate_indices_prefer_all_rows(rows: int, cols: int, num_indices: int) -> torch.Tensor:
"""Generate indices for a row x cols matrix, preferring at least one index per row if possible."""
indices = []
indices = [] # type: ignore[var-annotated]
n_per_row = math.ceil(num_indices / rows)
col_indices = list(range(cols))
@ -4854,7 +4856,7 @@ def do_test_empty_full(self, dtypes, layout, device):
int64_dtype, layout, device, fv + 5, False)
# FIXME: improve load_tests() documentation here
running_script_path = None
running_script_path = None # type: ignore[var-annotated]
def set_running_script_path():
global running_script_path
try:
@ -5116,7 +5118,7 @@ def copy_func(f):
argdefs=f.__defaults__,
closure=f.__closure__)
g = functools.update_wrapper(g, f)
g.__kwdefaults__ = f.__kwdefaults__
g.__kwdefaults__ = f.__kwdefaults__ # type: ignore[attr-defined]
return g
@ -5304,8 +5306,8 @@ class TestGradients(TestCase):
if is_iterable_of_tensors(sample.input):
all_args = chain(sample.input, sample.args, sample.kwargs.values())
else:
all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values()))
gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad))
all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values())) # type: ignore[assignment]
gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad)) # type: ignore[union-attr]
# Verifies sample input tensors should have no grad
# This may happen if the same tensor is used in two different SampleInputs
@ -5525,7 +5527,7 @@ def check_leaked_tensors(limit=1, matched_type=torch.Tensor):
try:
gc.collect()
gc.set_debug(gc.DEBUG_SAVEALL)
garbage_objs = []
garbage_objs = [] # type: ignore[var-annotated]
# run the user code, after cleaning any existing refcycles, and then check for new ones
# also allow usercode to check the garbage objs (e.g. for assertion) after exiting ctxmgr
@ -5539,7 +5541,7 @@ def check_leaked_tensors(limit=1, matched_type=torch.Tensor):
f"{num_garbage_objs} tensors were found in the garbage. Did you introduce a reference cycle?"
)
try:
import objgraph
import objgraph # type: ignore[import-not-found]
warnings.warn(
f"Dumping first {limit} objgraphs of leaked {matched_type}s rendered to png"
)