[BE][PYFMT] bump ruff format target version to py39: add parentheses around long with-statements

ghstack-source-id: 0563dca144
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145148
This commit is contained in:
Xuehai Pan 2025-02-10 22:00:03 +08:00
parent 50e4aa2549
commit 57fcc48993
14 changed files with 137 additions and 66 deletions

View file

@@ -916,10 +916,13 @@ def latency_experiment(args, model_iter_fn, model, example_inputs, mark, **kwarg
# inputs will incur high penalty then the next one.
maybe_mark_step(args)
with maybe_mark_profile(p=p, mark=mark), maybe_enable_compiled_autograd(
args.compiled_autograd,
fullgraph=args.nopython,
dynamic=args.dynamic_shapes,
with (
maybe_mark_profile(p=p, mark=mark),
maybe_enable_compiled_autograd(
args.compiled_autograd,
fullgraph=args.nopython,
dynamic=args.dynamic_shapes,
),
):
timings[rep], actual_output = timed(
model,
@@ -1090,10 +1093,13 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
# call mark_step between the 2 calls to make the comparison fair.
maybe_mark_step(args)
with maybe_mark_profile(p=p, mark="actual"), maybe_enable_compiled_autograd(
args.compiled_autograd,
fullgraph=args.nopython,
dynamic=args.dynamic_shapes,
with (
maybe_mark_profile(p=p, mark="actual"),
maybe_enable_compiled_autograd(
args.compiled_autograd,
fullgraph=args.nopython,
dynamic=args.dynamic_shapes,
),
):
timings[rep, 1], actual_output = timed(
model,
@@ -2445,12 +2451,15 @@ class BenchmarkRunner:
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
with maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
), maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
with (
maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
),
maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
),
):
dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
optimized_model_iter_fn, model, example_inputs, "dynamo"
@@ -2598,12 +2607,15 @@ class BenchmarkRunner:
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
with maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
), maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
with (
maybe_enable_compiled_autograd(
self.args.compiled_autograd,
fullgraph=self.args.nopython,
dynamic=self.args.dynamic_shapes,
),
maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
),
):
dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup(
optimized_model_iter_fn, model, example_inputs, "dynamo"

View file

@@ -56,9 +56,11 @@ class Benchmark(BenchmarkBase):
def _work(self):
# enable_cpp_symbolic_shape_guards has impact on this benchmark
# Keep using False value for consistency.
with fresh_inductor_cache(), torch._inductor.config.patch(
force_shape_pad=self._force_shape_pad
), torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False):
with (
fresh_inductor_cache(),
torch._inductor.config.patch(force_shape_pad=self._force_shape_pad),
torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False),
):
opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())(
self.m.cuda() if self._is_gpu else self.m
)

View file

@@ -40,6 +40,7 @@ standard_library = ["typing_extensions"]
[tool.ruff]
# TODO: remove manual set `target-version` in tools/linter/adapters/pyfmt_linter.py when we move to py39
target-version = "py38"
line-length = 88
src = ["caffe2", "torch", "torchgen", "functorch", "test"]

View file

@@ -513,8 +513,9 @@ def verify(
"could mean that your network is numerically unstable. Otherwise\n"
"it indicates a bug in PyTorch/ONNX; please file a bug report."
)
with Errors(msg, rtol=rtol, atol=atol) as errs, errs.addErrCtxt(
result_hint
with (
Errors(msg, rtol=rtol, atol=atol) as errs,
errs.addErrCtxt(result_hint),
):
for i, (x, y) in enumerate(zip(torch_out, backend_out)):
errs.checkAlmostEqual(x.data.cpu().numpy(), y, f"In output {i}")

View file

@@ -580,9 +580,16 @@ class TestDecomp(TestCase):
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
func = partial(op.get_op(), **kwargs)
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all=False
) as mode, enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self,
self.precision,
self.rel_tol,
dtype,
run_all=False,
) as mode,
enable_python_dispatcher(),
):
torch.autograd.gradcheck(func, args)
self.check_decomposed(aten_name, mode)
@@ -677,9 +684,16 @@
module_input.forward_input.args,
module_input.forward_input.kwargs,
)
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all=True
), enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self,
self.precision,
self.rel_tol,
dtype,
run_all=True,
),
enable_python_dispatcher(),
):
decomp_out = m(*args, **kwargs)
non_decomp_out = m(*args, **kwargs)
@@ -955,9 +969,16 @@ def forward(self, scores_1, mask_1, value_1):
# store the called list on the mode object instance and no
# explicit clearing is necessary as I will create a fresh mode
# for each region
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode, enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self,
self.precision,
self.rel_tol,
dtype,
run_all,
) as mode,
enable_python_dispatcher(),
):
decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
if run_without_python_dispatcher(mode):
# without this check, incorrect decomps at the python dispatcher level can still pass because
@@ -974,9 +995,16 @@ def forward(self, scores_1, mask_1, value_1):
):
cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode, enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self,
self.precision,
self.rel_tol,
dtype,
run_all,
) as mode,
enable_python_dispatcher(),
):
decomp_vjp_fn(cotangents)
if run_without_python_dispatcher(mode):
# without this check, incorrect decomps at the python dispatcher level can still pass because
@@ -993,9 +1021,16 @@ def forward(self, scores_1, mask_1, value_1):
kwargs = sample_input.kwargs
# A failure here might be because the decomposition for the op is wrong or because a
# decomposition used by the particular op is wrong.
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode, enable_python_dispatcher():
with (
self.DecompCrossRefMode(
self,
self.precision,
self.rel_tol,
dtype,
run_all,
) as mode,
enable_python_dispatcher(),
):
func(*args, **kwargs)
if run_without_python_dispatcher(mode):

View file

@@ -255,9 +255,11 @@ class TestForeach(TestCase):
else inputs
)
try:
with InplaceForeachVersionBumpCheck(
self, inputs[0]
) if op.is_inplace else nullcontext():
with (
InplaceForeachVersionBumpCheck(self, inputs[0])
if op.is_inplace
else nullcontext()
):
actual = op(inputs, self.is_cuda, is_fastpath)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):
@@ -278,9 +280,11 @@ class TestForeach(TestCase):
try:
op_kwargs = {}
op_kwargs.update(kwargs)
with InplaceForeachVersionBumpCheck(
self, inputs[0]
) if op.is_inplace else nullcontext():
with (
InplaceForeachVersionBumpCheck(self, inputs[0])
if op.is_inplace
else nullcontext()
):
actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])):

View file

@@ -2782,8 +2782,10 @@ class TestFakeTensor(TestCase):
with torch._subclasses.CrossRefFakeMode(
ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True
):
with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled(
False
with (
warnings.catch_warnings(),
context(),
torch.autograd.set_multithreading_enabled(False),
):
composite_compliance.compute_expected_grads(
op.get_op(),

View file

@@ -160,6 +160,7 @@ def run_ruff_format(content: str, path: Path) -> str:
"format",
"--config",
str(REPO_ROOT / "pyproject.toml"),
"--target-version=py39",
"--stdin-filename",
str(path),
"-",

View file

@@ -866,9 +866,10 @@ class TracingContext:
@contextlib.contextmanager
def clear_frame():
tc = TracingContext.get()
with unittest.mock.patch.object(
tc, "frame_summary_stack", []
), unittest.mock.patch.object(tc, "loc_in_frame", None):
with (
unittest.mock.patch.object(tc, "frame_summary_stack", []),
unittest.mock.patch.object(tc, "loc_in_frame", None),
):
try:
yield
except Exception as e:

View file

@@ -834,7 +834,9 @@ class OpOverload(OperatorBase):
if curr_mode not in self.python_key_table:
if isinstance(self, TorchBindOpOverload):
with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
with (
torch.utils._python_dispatch._pop_mode_temporarily() as mode
):
return torch._library.utils.handle_dispatch_mode(
mode, self, *args, **kwargs
)

View file

@@ -564,9 +564,13 @@ class Exporter:
# https://github.com/pytorch/pytorch/issues/103764
from torch.onnx._internal.fx import decomposition_skip
with self.options.diagnostic_context, decomposition_skip.enable_decomposition_skips(
self.options
), torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
with (
self.options.diagnostic_context,
decomposition_skip.enable_decomposition_skips(self.options),
torch._dynamo.config.patch(
dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
),
):
graph_module = self.options.fx_tracer.generate_fx(
self.options, self.model, self.model_args, self.model_kwargs
)

View file

@@ -68,7 +68,11 @@ class Decompose(_pass.Transform):
# Apply decomposition table to the input graph.
assert fake_mode is not None # for mypy
with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), fake_mode:
with (
fake_tensor.unset_fake_temporarily(),
python_dispatch.enable_python_dispatcher(),
fake_mode,
):
decomposed_module = proxy_tensor.make_fx(
module,
decomposition_table=self.decomposition_table,

View file

@@ -172,13 +172,12 @@ def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool):
.. deprecated:: 2.7
Please set training mode before exporting the model.
"""
with select_model_mode_for_export(
model, mode
) as mode_ctx, disable_apex_o2_state_dict_hook(
model
) as apex_ctx, setup_onnx_logging(
verbose
) as log_ctx, diagnostics.create_export_diagnostic_context() as diagnostic_ctx:
with (
select_model_mode_for_export(model, mode) as mode_ctx,
disable_apex_o2_state_dict_hook(model) as apex_ctx,
setup_onnx_logging(verbose) as log_ctx,
diagnostics.create_export_diagnostic_context() as diagnostic_ctx,
):
yield (mode_ctx, apex_ctx, log_ctx, diagnostic_ctx)

View file

@@ -1616,9 +1616,12 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
return saved_id[0]
return deserialized_objects[int(saved_id)]
with closing(
tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
) as tar, mkdtemp() as tmpdir:
with (
closing(
tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
) as tar,
mkdtemp() as tmpdir,
):
if pickle_module is _weights_only_unpickler:
raise RuntimeError(
"Cannot use ``weights_only=True`` with files saved in the "