diff --git a/test/conftest.py b/test/conftest.py
index 60288387e3d..e02f24ad9cb 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -135,8 +135,7 @@ class _NodeReporterReruns(_NodeReporter):
         else:
             assert isinstance(report.longrepr, tuple)
             filename, lineno, skipreason = report.longrepr
-            if skipreason.startswith("Skipped: "):
-                skipreason = skipreason[9:]
+            skipreason = skipreason.removeprefix("Skipped: ")
             details = f"{filename}:{lineno}: {skipreason}"
 
             skipped = ET.Element(
diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py
index f4edb188c58..2acaf67add6 100644
--- a/test/dynamo/test_python_autograd.py
+++ b/test/dynamo/test_python_autograd.py
@@ -1,5 +1,5 @@
 # Owner(s): ["module: dynamo"]
-from typing import Callable, List, NamedTuple, Optional
+from typing import Callable, NamedTuple, Optional
 
 import torch
 import torch._dynamo
@@ -81,7 +81,7 @@ def grad(L, desired_results: list[Variable]) -> list[Variable]:
     # look up dL_dentries. If a variable is never used to compute the loss,
     # we consider its gradient None, see the note below about zeros for more information.
 
-    def gather_grad(entries: List[str]):
+    def gather_grad(entries: list[str]):
         return [dL_d[entry] if entry in dL_d else None for entry in entries]
 
     # propagate the gradient information backward
@@ -127,7 +127,7 @@ def operator_mul(self: Variable, rhs: Variable) -> Variable:
     outputs = [r.name]
 
     # define backprop
-    def propagate(dL_doutputs: List[Variable]):
+    def propagate(dL_doutputs: list[Variable]):
         (dL_dr,) = dL_doutputs
 
         dr_dself = rhs  # partial derivative of r = self*rhs
@@ -150,7 +150,7 @@ def operator_add(self: Variable, rhs: Variable) -> Variable:
     r = Variable(self.value + rhs.value)
     # print(f'{r.name} = {self.name} + {rhs.name}')
 
-    def propagate(dL_doutputs: List[Variable]):
+    def propagate(dL_doutputs: list[Variable]):
         (dL_dr,) = dL_doutputs
         dr_dself = 1.0
         dr_drhs = 1.0
@@ -168,7 +168,7 @@ def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
     r = Variable(torch.sum(self.value), name=name)
     # print(f'{r.name} = {self.name}.sum()')
 
-    def propagate(dL_doutputs: List[Variable]):
+    def propagate(dL_doutputs: list[Variable]):
         (dL_dr,) = dL_doutputs
         size = self.value.size()
         return [dL_dr.expand(*size)]
@@ -179,12 +179,12 @@ def operator_sum(self: Variable, name: Optional[str]) -> "Variable":
     return r
 
 
-def operator_expand(self: Variable, sizes: List[int]) -> "Variable":
+def operator_expand(self: Variable, sizes: list[int]) -> "Variable":
     assert self.value.dim() == 0  # only works for scalars
     r = Variable(self.value.expand(sizes))
     # print(f'{r.name} = {self.name}.expand({sizes})')
 
-    def propagate(dL_doutputs: List[Variable]):
+    def propagate(dL_doutputs: list[Variable]):
         (dL_dr,) = dL_doutputs
         return [dL_dr.sum()]
 
diff --git a/test/functorch/discover_coverage.py b/test/functorch/discover_coverage.py
index 92ea79b8008..6d9d9e7e8a7 100644
--- a/test/functorch/discover_coverage.py
+++ b/test/functorch/discover_coverage.py
@@ -44,7 +44,7 @@ def get_public_overridable_apis(pytorch_root="/raid/rzou/pt/debug-cpu"):
             if line.startswith(".. autofunction::")
         ]
         lines = api_lines1 + api_lines2
-        lines = [line[7:] if line.startswith("Tensor.") else line for line in lines]
+        lines = [line.removeprefix("Tensor.") for line in lines]
        lines = [line for line in lines if hasattr(module, line)]
         for line in lines:
             api = getattr(module, line)
diff --git a/test/onnx/internal/test_diagnostics.py b/test/onnx/internal/test_diagnostics.py
index fc2ad3a426b..a7d979ff96d 100644
--- a/test/onnx/internal/test_diagnostics.py
+++ b/test/onnx/internal/test_diagnostics.py
@@ -6,7 +6,7 @@ import dataclasses
 import io
 import logging
 import typing
-from typing import AbstractSet, Protocol
+from typing import Protocol
 
 import torch
 from torch.onnx import errors
@@ -27,7 +27,7 @@ class _SarifLogBuilder(Protocol):
 
 def _assert_has_diagnostics(
     sarif_log_builder: _SarifLogBuilder,
-    rule_level_pairs: AbstractSet[tuple[infra.Rule, infra.Level]],
+    rule_level_pairs: set[tuple[infra.Rule, infra.Level]],
 ):
     sarif_log = sarif_log_builder.sarif_log()
     unseen_pairs = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs}
@@ -62,7 +62,7 @@ class _RuleCollectionForTest(infra.RuleCollection):
 def assert_all_diagnostics(
     test_suite: unittest.TestCase,
     sarif_log_builder: _SarifLogBuilder,
-    rule_level_pairs: AbstractSet[tuple[infra.Rule, infra.Level]],
+    rule_level_pairs: set[tuple[infra.Rule, infra.Level]],
 ):
     """Context manager to assert that all diagnostics are emitted.
 
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 3f9982763ca..24a8a128d8f 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -12,7 +12,7 @@ import tempfile
 import typing
 import unittest
 from types import BuiltinFunctionType
-from typing import Callable, NamedTuple, Optional, Union, List
+from typing import Callable, List, NamedTuple, Optional, Union
 
 import torch
 import torch.fx.experimental.meta_tracer
diff --git a/test/test_overrides.py b/test/test_overrides.py
index 25ca3651157..c2e08df2280 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -706,8 +706,7 @@ def generate_tensor_like_override_tests(cls):
         for arg in annotated_args[func]:
             # Guess valid input to aten function based on type of argument
             t = arg["simple_type"]
-            if t.endswith("?"):
-                t = t[:-1]
+            t = t.removesuffix("?")
             if t == "Tensor" and is_method and arg["name"] == "self":
                 # See "Note: properties and __get__"
                 func = func.__get__(instance_gen())