diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 839f5e619fe..a200b61b23f 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -916,10 +916,13 @@ def latency_experiment(args, model_iter_fn, model, example_inputs, mark, **kwarg # inputs will incur high penalty then the next one. maybe_mark_step(args) - with maybe_mark_profile(p=p, mark=mark), maybe_enable_compiled_autograd( - args.compiled_autograd, - fullgraph=args.nopython, - dynamic=args.dynamic_shapes, + with ( + maybe_mark_profile(p=p, mark=mark), + maybe_enable_compiled_autograd( + args.compiled_autograd, + fullgraph=args.nopython, + dynamic=args.dynamic_shapes, + ), ): timings[rep], actual_output = timed( model, @@ -1090,10 +1093,13 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs): # call mark_step between the 2 calls to make the comparison fair. maybe_mark_step(args) - with maybe_mark_profile(p=p, mark="actual"), maybe_enable_compiled_autograd( - args.compiled_autograd, - fullgraph=args.nopython, - dynamic=args.dynamic_shapes, + with ( + maybe_mark_profile(p=p, mark="actual"), + maybe_enable_compiled_autograd( + args.compiled_autograd, + fullgraph=args.nopython, + dynamic=args.dynamic_shapes, + ), ): timings[rep, 1], actual_output = timed( model, @@ -2445,12 +2451,15 @@ class BenchmarkRunner: else: optimized_model_iter_fn = optimize_ctx(self.model_iter_fn) - with maybe_enable_compiled_autograd( - self.args.compiled_autograd, - fullgraph=self.args.nopython, - dynamic=self.args.dynamic_shapes, - ), maybe_snapshot_memory( - self.args.snapshot_memory, f"compiled_{self.args.only}" + with ( + maybe_enable_compiled_autograd( + self.args.compiled_autograd, + fullgraph=self.args.nopython, + dynamic=self.args.dynamic_shapes, + ), + maybe_snapshot_memory( + self.args.snapshot_memory, f"compiled_{self.args.only}" + ), ): dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup( optimized_model_iter_fn, model, example_inputs, "dynamo" @@ -2598,12 +2607,15 @@ class BenchmarkRunner: else: optimized_model_iter_fn = optimize_ctx(self.model_iter_fn) - with maybe_enable_compiled_autograd( - self.args.compiled_autograd, - fullgraph=self.args.nopython, - dynamic=self.args.dynamic_shapes, - ), maybe_snapshot_memory( - self.args.snapshot_memory, f"compiled_{self.args.only}" + with ( + maybe_enable_compiled_autograd( + self.args.compiled_autograd, + fullgraph=self.args.nopython, + dynamic=self.args.dynamic_shapes, + ), + maybe_snapshot_memory( + self.args.snapshot_memory, f"compiled_{self.args.only}" + ), ): dynamo_latency, dynamo_peak_mem, dynamo_stats = warmup( optimized_model_iter_fn, model, example_inputs, "dynamo" diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py index 0c69b0556ee..18d753b3a7c 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/basic_modules_benchmarks.py @@ -56,9 +56,11 @@ class Benchmark(BenchmarkBase): def _work(self): # enable_cpp_symbolic_shape_guards has impact on this benchmark # Keep using False value for consistency. - with fresh_inductor_cache(), torch._inductor.config.patch( - force_shape_pad=self._force_shape_pad - ), torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False): + with ( + fresh_inductor_cache(), + torch._inductor.config.patch(force_shape_pad=self._force_shape_pad), + torch._dynamo.config.patch("enable_cpp_symbolic_shape_guards", False), + ): opt_m = torch.compile(backend=self.backend(), dynamic=self.is_dynamic())( self.m.cuda() if self._is_gpu else self.m ) diff --git a/pyproject.toml b/pyproject.toml index 602e54daf65..eb492fe6868 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ standard_library = ["typing_extensions"] [tool.ruff] +# TODO: remove manual set `target-version` in tools/linter/adapters/pyfmt_linter.py when we move to py39 target-version = "py38" line-length = 88 src = ["caffe2", "torch", "torchgen", "functorch", "test"] diff --git a/test/onnx/verify.py b/test/onnx/verify.py index 0dc2975df14..2e7ef39b434 100644 --- a/test/onnx/verify.py +++ b/test/onnx/verify.py @@ -513,8 +513,9 @@ def verify( "could mean that your network is numerically unstable. Otherwise\n" "it indicates a bug in PyTorch/ONNX; please file a bug report." ) - with Errors(msg, rtol=rtol, atol=atol) as errs, errs.addErrCtxt( - result_hint + with ( + Errors(msg, rtol=rtol, atol=atol) as errs, + errs.addErrCtxt(result_hint), ): for i, (x, y) in enumerate(zip(torch_out, backend_out)): errs.checkAlmostEqual(x.data.cpu().numpy(), y, f"In output {i}") diff --git a/test/test_decomp.py b/test/test_decomp.py index c0a0b66300c..7fd64da149d 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -580,9 +580,16 @@ class TestDecomp(TestCase): args = [sample_input.input] + list(sample_input.args) kwargs = sample_input.kwargs func = partial(op.get_op(), **kwargs) - with self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all=False - ) as mode, enable_python_dispatcher(): + with ( + self.DecompCrossRefMode( + self, + self.precision, + self.rel_tol, + dtype, + run_all=False, + ) as mode, + enable_python_dispatcher(), + ): torch.autograd.gradcheck(func, args) self.check_decomposed(aten_name, mode) @@ -677,9 +684,16 @@ class TestDecomp(TestCase): module_input.forward_input.args, module_input.forward_input.kwargs, ) - with self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all=True - ), enable_python_dispatcher(): + with ( + self.DecompCrossRefMode( + self, + self.precision, + self.rel_tol, + dtype, + run_all=True, + ), + enable_python_dispatcher(), + ): decomp_out = m(*args, **kwargs) non_decomp_out = m(*args, **kwargs) @@ -955,9 +969,16 @@ def forward(self, scores_1, mask_1, value_1): # store the called list on the mode object instance and no # explicit clearing is necessary as I will create a fresh mode # for each region - with self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all - ) as mode, enable_python_dispatcher(): + with ( + self.DecompCrossRefMode( + self, + self.precision, + self.rel_tol, + dtype, + run_all, + ) as mode, + enable_python_dispatcher(), + ): decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals) if run_without_python_dispatcher(mode): # without this check, incorrect decomps at the python dispatcher level can still pass because @@ -974,9 +995,16 @@ def forward(self, scores_1, mask_1, value_1): ): cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out) - with self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all - ) as mode, enable_python_dispatcher(): + with ( + self.DecompCrossRefMode( + self, + self.precision, + self.rel_tol, + dtype, + run_all, + ) as mode, + enable_python_dispatcher(), + ): decomp_vjp_fn(cotangents) if run_without_python_dispatcher(mode): # without this check, incorrect decomps at the python dispatcher level can still pass because @@ -993,9 +1021,16 @@ def forward(self, scores_1, mask_1, value_1): kwargs = sample_input.kwargs # A failure here might be because the decomposition for the op is wrong or because a # decomposition used by the particular op is wrong. - with self.DecompCrossRefMode( - self, self.precision, self.rel_tol, dtype, run_all - ) as mode, enable_python_dispatcher(): + with ( + self.DecompCrossRefMode( + self, + self.precision, + self.rel_tol, + dtype, + run_all, + ) as mode, + enable_python_dispatcher(), + ): func(*args, **kwargs) if run_without_python_dispatcher(mode): diff --git a/test/test_foreach.py b/test/test_foreach.py index c0c81e09e00..3c1ffcaebb7 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -255,9 +255,11 @@ class TestForeach(TestCase): else inputs ) try: - with InplaceForeachVersionBumpCheck( - self, inputs[0] - ) if op.is_inplace else nullcontext(): + with ( + InplaceForeachVersionBumpCheck(self, inputs[0]) + if op.is_inplace + else nullcontext() + ): actual = op(inputs, self.is_cuda, is_fastpath) except RuntimeError as e: with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])): @@ -278,9 +280,11 @@ class TestForeach(TestCase): try: op_kwargs = {} op_kwargs.update(kwargs) - with InplaceForeachVersionBumpCheck( - self, inputs[0] - ) if op.is_inplace else nullcontext(): + with ( + InplaceForeachVersionBumpCheck(self, inputs[0]) + if op.is_inplace + else nullcontext() + ): actual = op(inputs, self.is_cuda, is_fastpath, **op_kwargs) except RuntimeError as e: with self.assertRaisesRegex(type(e), re.escape(str(e).splitlines()[0])): diff --git a/test/test_ops.py b/test/test_ops.py index bce697498ac..0502bdeb811 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2782,8 +2782,10 @@ class TestFakeTensor(TestCase): with torch._subclasses.CrossRefFakeMode( ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=True ): - with warnings.catch_warnings(), context(), torch.autograd.set_multithreading_enabled( - False + with ( + warnings.catch_warnings(), + context(), + torch.autograd.set_multithreading_enabled(False), ): composite_compliance.compute_expected_grads( op.get_op(), diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index ae292100a06..7772816dcd0 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -160,6 +160,7 @@ def run_ruff_format(content: str, path: Path) -> str: "format", "--config", str(REPO_ROOT / "pyproject.toml"), + "--target-version=py39", "--stdin-filename", str(path), "-", diff --git a/torch/_guards.py b/torch/_guards.py index 53353abe752..e299e3e7eef 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -866,9 +866,10 @@ class TracingContext: @contextlib.contextmanager def clear_frame(): tc = TracingContext.get() - with unittest.mock.patch.object( - tc, "frame_summary_stack", [] - ), unittest.mock.patch.object(tc, "loc_in_frame", None): + with ( + unittest.mock.patch.object(tc, "frame_summary_stack", []), + unittest.mock.patch.object(tc, "loc_in_frame", None), + ): try: yield except Exception as e: diff --git a/torch/_ops.py b/torch/_ops.py index 4df2de10539..84b48ed1dec 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -834,7 +834,9 @@ class OpOverload(OperatorBase): if curr_mode not in self.python_key_table: if isinstance(self, TorchBindOpOverload): - with torch.utils._python_dispatch._pop_mode_temporarily() as mode: + with ( + torch.utils._python_dispatch._pop_mode_temporarily() as mode + ): return torch._library.utils.handle_dispatch_mode( mode, self, *args, **kwargs ) diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index ad196630f9d..8c697ad786f 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -564,9 +564,13 @@ class Exporter: # https://github.com/pytorch/pytorch/issues/103764 from torch.onnx._internal.fx import decomposition_skip - with self.options.diagnostic_context, decomposition_skip.enable_decomposition_skips( - self.options - ), torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)): + with ( + self.options.diagnostic_context, + decomposition_skip.enable_decomposition_skips(self.options), + torch._dynamo.config.patch( + dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG) + ), + ): graph_module = self.options.fx_tracer.generate_fx( self.options, self.model, self.model_args, self.model_kwargs ) diff --git a/torch/onnx/_internal/fx/passes/decomp.py b/torch/onnx/_internal/fx/passes/decomp.py index 06ec55c83da..9a8c9a9ddb1 100644 --- a/torch/onnx/_internal/fx/passes/decomp.py +++ b/torch/onnx/_internal/fx/passes/decomp.py @@ -68,7 +68,11 @@ class Decompose(_pass.Transform): # Apply decomposition table to the input graph. assert fake_mode is not None # for mypy - with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), fake_mode: + with ( + fake_tensor.unset_fake_temporarily(), + python_dispatch.enable_python_dispatcher(), + fake_mode, + ): decomposed_module = proxy_tensor.make_fx( module, decomposition_table=self.decomposition_table, diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 87ef237bb0f..2f6096ff6f5 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -172,13 +172,12 @@ def exporter_context(model, mode: _C_onnx.TrainingMode, verbose: bool): .. deprecated:: 2.7 Please set training mode before exporting the model. """ - with select_model_mode_for_export( - model, mode - ) as mode_ctx, disable_apex_o2_state_dict_hook( - model - ) as apex_ctx, setup_onnx_logging( - verbose - ) as log_ctx, diagnostics.create_export_diagnostic_context() as diagnostic_ctx: + with ( + select_model_mode_for_export(model, mode) as mode_ctx, + disable_apex_o2_state_dict_hook(model) as apex_ctx, + setup_onnx_logging(verbose) as log_ctx, + diagnostics.create_export_diagnostic_context() as diagnostic_ctx, + ): yield (mode_ctx, apex_ctx, log_ctx, diagnostic_ctx) diff --git a/torch/serialization.py b/torch/serialization.py index a9e352d413c..514cd318c1c 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1616,9 +1616,12 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args): return saved_id[0] return deserialized_objects[int(saved_id)] - with closing( - tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT) - ) as tar, mkdtemp() as tmpdir: + with ( + closing( + tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT) + ) as tar, + mkdtemp() as tmpdir, + ): if pickle_module is _weights_only_unpickler: raise RuntimeError( "Cannot use ``weights_only=True`` with files saved in the "