diff --git a/torch/_functorch/_activation_checkpointing/graph_info_provider.py b/torch/_functorch/_activation_checkpointing/graph_info_provider.py index 75cbda5264d..581612921f0 100644 --- a/torch/_functorch/_activation_checkpointing/graph_info_provider.py +++ b/torch/_functorch/_activation_checkpointing/graph_info_provider.py @@ -276,7 +276,7 @@ class GraphInfoProvider: vmin=min(self.get_knapsack_memory_input()), vmax=max(self.get_knapsack_memory_input()), ) - cmap = cm.viridis + cmap = cm.viridis # type: ignore[attr-defined] # Assign colors based on memory node_colors = [ diff --git a/torch/mps/profiler.py b/torch/mps/profiler.py index ac94645e195..6e194bb63b2 100644 --- a/torch/mps/profiler.py +++ b/torch/mps/profiler.py @@ -72,21 +72,21 @@ def is_metal_capture_enabled() -> bool: """Checks if `metal_capture` context manager is usable To enable metal capture, set MTL_CAPTURE_ENABLED envvar """ - return torch._C._mps_isCaptureEnabled() + return torch._C._mps_isCaptureEnabled() # type: ignore[attr-defined] def is_capturing_metal() -> bool: """Cheks if metal capture is in progress""" - return torch._C._mps_isCapturing() + return torch._C._mps_isCapturing() # type: ignore[attr-defined] @contextlib.contextmanager def metal_capture(fname: str): """Conext manager that enables capturing of Metal calls into gputrace""" try: - torch._C._mps_startCapture(fname) + torch._C._mps_startCapture(fname) # type: ignore[attr-defined] yield # Drain all the work that were enqueued during the context call torch.mps.synchronize() finally: - torch._C._mps_stopCapture() + torch._C._mps_stopCapture() # type: ignore[attr-defined] diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index eeedea8f129..1dfc48e8b01 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1,4 +1,4 @@ -# mypy: ignore-errors +# mypy: allow-untyped-defs r"""Importing this file must **not** initialize CUDA context. test_distributed relies on this assumption to properly run. This means that when this is imported @@ -156,6 +156,7 @@ class TestEnvironment: implied_by_fn=lambda: False, ): enabled = default + env_var_val = None if env_var is not None: env_var_val = os.getenv(env_var) enabled = enabled_fn(env_var_val, default) @@ -351,17 +352,15 @@ def get_tracked_input() -> Optional[TrackedInput]: test_fn = extract_test_fn() if test_fn is None: return None - if not hasattr(test_fn, "tracked_input"): - return None - return test_fn.tracked_input + return getattr(test_fn, "tracked_input", None) -def clear_tracked_input(): +def clear_tracked_input() -> None: test_fn = extract_test_fn() if test_fn is None: return if not hasattr(test_fn, "tracked_input"): - return None - test_fn.tracked_input = None + return + test_fn.tracked_input = None # type: ignore[attr-defined] # Wraps an iterator and tracks the most recent value the iterator produces # for debugging purposes. Tracked values are stored on the test function. @@ -428,7 +427,7 @@ class TrackedInputIter: return if not hasattr(self.test_fn, "tracked_input"): return - self.test_fn.tracked_input = tracked_input + self.test_fn.tracked_input = tracked_input # type: ignore[attr-defined] class _TestParametrizer: """ @@ -690,7 +689,7 @@ class parametrize(_TestParametrizer): for idx, values in enumerate(self.arg_values): maybe_name = None - decorators = [] + decorators: List[Any] = [] if isinstance(values, subtest): sub = values values = sub.arg_values @@ -705,7 +704,7 @@ class parametrize(_TestParametrizer): else: gen_test = test - values = list(values) if len(self.arg_names) > 1 else [values] + values = list(values) if len(self.arg_names) > 1 else [values] # type: ignore[call-overload] if len(values) != len(self.arg_names): raise RuntimeError(f'Expected # values == # arg names, but got: {len(values)} ' f'values and {len(self.arg_names)} names for test "{test.__name__}"') @@ -863,6 +862,8 @@ def cppProfilingFlagsToProfilingMode(): @contextmanager def enable_profiling_mode_for_profiling_tests(): + old_prof_exec_state = False + old_prof_mode_state = False if GRAPH_EXECUTOR == ProfilingMode.PROFILING: old_prof_exec_state = torch._C._jit_set_profiling_executor(True) old_prof_mode_state = torch._C._get_graph_executor_optimize(True) @@ -1175,7 +1176,7 @@ def sanitize_pytest_xml(xml_file: str): def get_pytest_test_cases(argv: List[str]) -> List[str]: class TestCollectorPlugin: def __init__(self) -> None: - self.tests = [] + self.tests: List[Any] = [] def pytest_collection_finish(self, session): for item in session.items: @@ -1286,6 +1287,7 @@ def run_tests(argv=UNITTEST_ARGS): assert not failed, "Some test shards have failed" elif USE_PYTEST: pytest_args = argv + ["--use-main-module"] + test_report_path = "" if TEST_SAVE_XML: test_report_path = get_report_path(pytest=True) print(f'Test results will be stored in {test_report_path}') @@ -1411,7 +1413,7 @@ else: def is_privateuse1_backend_available(): privateuse1_backend_name = torch._C._get_privateuse1_backend_name() privateuse1_backend_module = getattr(torch, privateuse1_backend_name, None) - return hasattr(privateuse1_backend_module, "is_available") and privateuse1_backend_module.is_available() + return (is_available := getattr(privateuse1_backend_module, "is_available", None)) and is_available() IS_FILESYSTEM_UTF8_ENCODING = sys.getfilesystemencoding() == 'utf-8' @@ -1628,8 +1630,8 @@ def skipIfTorchDynamo(msg="test doesn't currently work with dynamo"): assert isinstance(fn, type) if TEST_WITH_TORCHDYNAMO: - fn.__unittest_skip__ = True - fn.__unittest_skip_why__ = msg + fn.__unittest_skip__ = True # type: ignore[attr-defined] + fn.__unittest_skip_why__ = msg # type: ignore[attr-defined] return fn @@ -1649,8 +1651,8 @@ def skipIfTorchInductor(msg="test doesn't currently work with torchinductor", assert isinstance(fn, type) if condition: - fn.__unittest_skip__ = True - fn.__unittest_skip_why__ = msg + fn.__unittest_skip__ = True # type: ignore[attr-defined] + fn.__unittest_skip_why__ = msg # type: ignore[attr-defined] return fn @@ -1723,8 +1725,8 @@ def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT exe assert isinstance(fn, type) if GRAPH_EXECUTOR == ProfilingMode.LEGACY: - fn.__unittest_skip__ = True - fn.__unittest_skip_why__ = msg + fn.__unittest_skip__ = True # type: ignore[attr-defined] + fn.__unittest_skip_why__ = msg # type: ignore[attr-defined] return fn @@ -1832,8 +1834,8 @@ def skipIfNNModuleInlined( assert isinstance(fn, type) if condition: - fn.__unittest_skip__ = True - fn.__unittest_skip_why__ = msg + fn.__unittest_skip__ = True # type: ignore[attr-defined] + fn.__unittest_skip_why__ = msg # type: ignore[attr-defined] return fn @@ -2013,17 +2015,17 @@ class DeterministicGuard: def __enter__(self): self.deterministic_restore = torch.are_deterministic_algorithms_enabled() self.warn_only_restore = torch.is_deterministic_algorithms_warn_only_enabled() - self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory + self.fill_uninitialized_memory_restore = torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined] torch.use_deterministic_algorithms( self.deterministic, warn_only=self.warn_only) - torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory + torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory # type: ignore[attr-defined] def __exit__(self, exception_type, exception_value, traceback): torch.use_deterministic_algorithms( self.deterministic_restore, warn_only=self.warn_only_restore) - torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore + torch.utils.deterministic.fill_uninitialized_memory = self.fill_uninitialized_memory_restore # type: ignore[attr-defined] class AlwaysWarnTypedStorageRemoval: def __init__(self, always_warn): @@ -2176,7 +2178,7 @@ def _test_function(fn, device): def skipIfNoXNNPACK(fn): @wraps(fn) def wrapper(*args, **kwargs): - if not torch.backends.xnnpack.enabled: + if not torch.backends.xnnpack.enabled: # type: ignore[attr-defined] raise unittest.SkipTest('XNNPACK must be enabled for these tests. Please build with USE_XNNPACK=1.') else: fn(*args, **kwargs) @@ -2280,7 +2282,7 @@ def to_gpu(obj, type_map=None): res.requires_grad = obj.requires_grad return res elif torch.is_storage(obj): - return obj.new().resize_(obj.size()).copy_(obj) + return obj.new().resize_(obj.size()).copy_(obj) # type: ignore[attr-defined, union-attr] elif isinstance(obj, list): return [to_gpu(o, type_map) for o in obj] elif isinstance(obj, tuple): @@ -2482,13 +2484,13 @@ class CudaMemoryLeakCheck: if not discrepancy_detected: continue - if caching_allocator_discrepancy and not driver_discrepancy: + if caching_allocator_discrepancy and not driver_discrepancy: # type: ignore[possibly-undefined] # Just raises a warning if the leak is not validated by the # driver API # NOTE: this may be a problem with how the caching allocator collects its # statistics or a leak too small to trigger the allocation of an # additional block of memory by the CUDA driver - msg = ("CUDA caching allocator reports a memory leak not " + msg = ("CUDA caching allocator reports a memory leak not " # type: ignore[possibly-undefined] f"verified by the driver API in {self.name}! " f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} " f"and is now reported as {caching_allocator_mem_allocated} " @@ -2498,7 +2500,7 @@ class CudaMemoryLeakCheck: elif caching_allocator_discrepancy and driver_discrepancy: # A caching allocator discrepancy validated by the driver API is a # failure (except on ROCm, see below) - msg = (f"CUDA driver API confirmed a leak in {self.name}! " + msg = (f"CUDA driver API confirmed a leak in {self.name}! " # type: ignore[possibly-undefined] f"Caching allocator allocated memory was {self.caching_allocator_befores[i]} " f"and is now reported as {caching_allocator_mem_allocated} " f"on device {i}. " @@ -3022,7 +3024,7 @@ class TestCase(expecttest.TestCase): lambda repro_parts=repro_parts: print_repro_on_failure(repro_parts)) except Exception as e: # Don't fail entirely if we can't get the test filename - log.info("could not print repro string", extra=str(e)) + log.info("could not print repro string", extra=str(e)) # type: ignore[arg-type] def assertLeaksNoCudaTensors(self, name=None): name = self.id() if name is None else name @@ -3133,7 +3135,7 @@ class TestCase(expecttest.TestCase): using_unittest = isinstance(result, unittest.TestResult) super_run = super().run - test_cls = super_run.__self__ + test_cls = super_run.__self__ # type: ignore[attr-defined] # Are we compiling? compiled = TEST_WITH_TORCHDYNAMO or TEST_WITH_AOT_EAGER or TEST_WITH_TORCHINDUCTOR @@ -3244,9 +3246,9 @@ class TestCase(expecttest.TestCase): # Create dummy TestInfo to record results correctly from xmlrunner.result import _TestInfo # type: ignore[import] case = _TestInfo(result, case) - case.output = _TestInfo.ERROR - case.elapsed_time = 0.0 - case.test_description = "TestSuiteEarlyFailure" + case.output = _TestInfo.ERROR # type: ignore[attr-defined] + case.elapsed_time = 0.0 # type: ignore[attr-defined] + case.test_description = "TestSuiteEarlyFailure" # type: ignore[attr-defined] # This shouldn't really happen, but if does add fake failure # For more details see https://github.com/pytorch/pytorch/issues/71973 result.failures.append((case, "TestSuite execution was aborted early")) @@ -3680,7 +3682,7 @@ class TestCase(expecttest.TestCase): return get_sparse_data_with_block(pattern, blocksize) # batch data is created recursively: - batch_data = {} + batch_data = {} # type: ignore[var-annotated] for i, item in enumerate(pattern): for layout, d in get_batch_sparse_data(item, blocksize).items(): target = batch_data.get(layout) @@ -3842,30 +3844,30 @@ class TestCase(expecttest.TestCase): for blocksize in blocksizes: for densesize in densesizes: if layout == torch.strided: - indices = () + indices = () # type: ignore[assignment] values = torch.empty((basesize + densesize), device=device, dtype=dtype) elif layout == torch.sparse_coo: - indices = (torch.empty(len(basesize), 0, device=device, dtype=index_dtype),) + indices = (torch.empty(len(basesize), 0, device=device, dtype=index_dtype),) # type: ignore[assignment] values = torch.empty((0, *densesize), device=device, dtype=dtype) elif layout == torch.sparse_csr: crow_indices = torch.tensor([0] * (basesize[0] + 1), device=device, dtype=index_dtype) col_indices = torch.empty(0, device=device, dtype=index_dtype) - indices = (crow_indices, col_indices) + indices = (crow_indices, col_indices) # type: ignore[assignment] values = torch.empty((0, *densesize), device=device, dtype=dtype) elif layout == torch.sparse_csc: ccol_indices = torch.tensor([0] * (basesize[1] + 1), device=device, dtype=index_dtype) row_indices = torch.empty(0, device=device, dtype=index_dtype) - indices = (ccol_indices, row_indices) + indices = (ccol_indices, row_indices) # type: ignore[assignment] values = torch.empty((0, *densesize), device=device, dtype=dtype) elif layout == torch.sparse_bsr: crow_indices = torch.tensor([0] * (basesize[0] // blocksize[0] + 1), device=device, dtype=index_dtype) col_indices = torch.empty(0, device=device, dtype=index_dtype) - indices = (crow_indices, col_indices) + indices = (crow_indices, col_indices) # type: ignore[assignment] values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype) elif layout == torch.sparse_bsc: ccol_indices = torch.tensor([0] * (basesize[1] // blocksize[1] + 1), device=device, dtype=index_dtype) row_indices = torch.empty(0, device=device, dtype=index_dtype) - indices = (ccol_indices, row_indices) + indices = (ccol_indices, row_indices) # type: ignore[assignment] values = torch.empty((0, *blocksize, *densesize), device=device, dtype=dtype) else: assert 0 # unreachable @@ -4029,9 +4031,9 @@ class TestCase(expecttest.TestCase): if error_metas: # See [ErrorMeta Cycles] - error_metas = [error_metas] + error_metas = [error_metas] # type: ignore[list-item] # TODO: compose all metas into one AssertionError - raise error_metas.pop()[0].to_error( + raise error_metas.pop()[0].to_error( # type: ignore[index] # This emulates unittest.TestCase's behavior if a custom message passed and # TestCase.longMessage (https://docs.python.org/3/library/unittest.html#unittest.TestCase.longMessage) # is True (default) @@ -4062,7 +4064,7 @@ class TestCase(expecttest.TestCase): context: Optional[AssertRaisesContextIgnoreNotImplementedError] = \ AssertRaisesContextIgnoreNotImplementedError(expected_exception, self) # type: ignore[call-arg] try: - return context.handle('assertRaises', args, kwargs) # type: ignore[union-attr] + return context.handle('assertRaises', args, kwargs) # type: ignore[union-attr, arg-type] finally: # see https://bugs.python.org/issue23890 context = None @@ -4087,7 +4089,7 @@ class TestCase(expecttest.TestCase): if self._ignore_not_implemented_error: context = AssertRaisesContextIgnoreNotImplementedError( # type: ignore[call-arg] expected_exception, self, expected_regex) - return context.handle('assertRaisesRegex', args, kwargs) # type: ignore[attr-defined] + return context.handle('assertRaisesRegex', args, kwargs) # type: ignore[attr-defined, arg-type] else: return super().assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs) @@ -4185,8 +4187,8 @@ class TestCase(expecttest.TestCase): # test/common_utils.py, but it matters in onnx-pytorch module_id = self.__class__.__module__ munged_id = remove_prefix(self.id(), module_id + ".") - test_file = os.path.realpath(sys.modules[module_id].__file__) - expected_file = os.path.join(os.path.dirname(test_file), + test_file = os.path.realpath(sys.modules[module_id].__file__) # type: ignore[type-var] + expected_file = os.path.join(os.path.dirname(test_file), # type: ignore[type-var, arg-type] "expect", munged_id) @@ -4284,7 +4286,7 @@ class TestCase(expecttest.TestCase): if attrs.get("operator") == operator: break - self.assertEqual(attrs["operator"], operator) + self.assertEqual(attrs["operator"], operator) # type: ignore[possibly-undefined] self.assertEqual(attrs.get("overload_name", ""), overload_name) def check_nondeterministic_alert(self, fn, caller_name, should_alert=True): @@ -4702,7 +4704,7 @@ def random_lowrank_matrix(rank, rows, columns, *batch_dims, **kwargs): def _generate_indices_prefer_all_rows(rows: int, cols: int, num_indices: int) -> torch.Tensor: """Generate indices for a row x cols matrix, preferring at least one index per row if possible.""" - indices = [] + indices = [] # type: ignore[var-annotated] n_per_row = math.ceil(num_indices / rows) col_indices = list(range(cols)) @@ -4854,7 +4856,7 @@ def do_test_empty_full(self, dtypes, layout, device): int64_dtype, layout, device, fv + 5, False) # FIXME: improve load_tests() documentation here -running_script_path = None +running_script_path = None # type: ignore[var-annotated] def set_running_script_path(): global running_script_path try: @@ -5116,7 +5118,7 @@ def copy_func(f): argdefs=f.__defaults__, closure=f.__closure__) g = functools.update_wrapper(g, f) - g.__kwdefaults__ = f.__kwdefaults__ + g.__kwdefaults__ = f.__kwdefaults__ # type: ignore[attr-defined] return g @@ -5304,8 +5306,8 @@ class TestGradients(TestCase): if is_iterable_of_tensors(sample.input): all_args = chain(sample.input, sample.args, sample.kwargs.values()) else: - all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values())) - gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad)) + all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values())) # type: ignore[assignment] + gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad)) # type: ignore[union-attr] # Verifies sample input tensors should have no grad # This may happen if the same tensor is used in two different SampleInputs @@ -5525,7 +5527,7 @@ def check_leaked_tensors(limit=1, matched_type=torch.Tensor): try: gc.collect() gc.set_debug(gc.DEBUG_SAVEALL) - garbage_objs = [] + garbage_objs = [] # type: ignore[var-annotated] # run the user code, after cleaning any existing refcycles, and then check for new ones # also allow usercode to check the garbage objs (e.g. for assertion) after exiting ctxmgr @@ -5539,7 +5541,7 @@ def check_leaked_tensors(limit=1, matched_type=torch.Tensor): f"{num_garbage_objs} tensors were found in the garbage. Did you introduce a reference cycle?" ) try: - import objgraph + import objgraph # type: ignore[import-not-found] warnings.warn( f"Dumping first {limit} objgraphs of leaked {matched_type}s rendered to png" )