diff --git a/.github/scripts/label_utils.py b/.github/scripts/label_utils.py index 8da0c49ba92..00c7cbf8e32 100644 --- a/.github/scripts/label_utils.py +++ b/.github/scripts/label_utils.py @@ -63,9 +63,9 @@ def gh_get_labels(org: str, repo: str) -> list[str]: update_labels(labels, info) last_page = get_last_page_num_from_header(header) - assert ( - last_page > 0 - ), "Error reading header info to determine total number of pages of labels" + assert last_page > 0, ( + "Error reading header info to determine total number of pages of labels" + ) for page_number in range(2, last_page + 1): # skip page 1 _, info = request_for_labels(prefix + f"&page={page_number}") update_labels(labels, info) diff --git a/.lintrunner.toml b/.lintrunner.toml index 347502fc07d..8b1439719b0 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1476,7 +1476,7 @@ init_command = [ 'black==23.12.1', 'usort==1.0.8.post1', 'isort==5.13.2', - 'ruff==0.8.4', # sync with RUFF + 'ruff==0.9.2', # sync with RUFF ] is_formatter = true @@ -1561,7 +1561,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'ruff==0.8.4', # sync with PYFMT + 'ruff==0.9.2', # sync with PYFMT ] is_formatter = true diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index a200b61b23f..56616a44882 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1006,9 +1006,9 @@ def latency_experiment_summary(suite_name, args, model, timings, **kwargs): row, ) c_headers, c_data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True) - assert ( - output_filename.find(".csv") > 0 - ), f"expected output_filename to be a .csv, but got {output_filename}" + assert output_filename.find(".csv") > 0, ( + f"expected output_filename to be a .csv, but got {output_filename}" + ) write_outputs( output_filename[:-4] + "_compilation_metrics.csv", first_headers + c_headers, @@ -1182,9 +1182,9 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs): row, ) c_headers, c_data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True) - assert ( - output_filename.find(".csv") > 0 - ), f"expected output_filename to be a .csv, but got {output_filename}" + assert output_filename.find(".csv") > 0, ( + f"expected output_filename to be a .csv, but got {output_filename}" + ) write_outputs( output_filename[:-4] + "_compilation_metrics.csv", first_headers + c_headers, @@ -1997,16 +1997,16 @@ class BenchmarkRunner: def deepcopy_and_maybe_parallelize(self, model): model = self.deepcopy_model(model) if self.args.ddp: - assert ( - torch.distributed.is_available() - ), "Can't use DDP without a distributed enabled build" + assert torch.distributed.is_available(), ( + "Can't use DDP without a distributed enabled build" + ) from torch.nn.parallel import DistributedDataParallel as DDP model = DDP(model, find_unused_parameters=True) elif self.args.fsdp: - assert ( - torch.distributed.is_available() - ), "Can't use FSDP without a distributed enabled build" + assert torch.distributed.is_available(), ( + "Can't use FSDP without a distributed enabled build" + ) from torch.distributed.fsdp import ( FullyShardedDataParallel as FSDP, MixedPrecision, @@ -2375,9 +2375,9 @@ class BenchmarkRunner: self, name, model, example_inputs, optimize_ctx, experiment, tag=None ): "Run performance test in non-alternately." - assert ( - experiment.func is latency_experiment - ), "Must run with latency_experiment." + assert experiment.func is latency_experiment, ( + "Must run with latency_experiment." + ) def warmup(fn, model, example_inputs, mode, niters=10): peak_mem = 0 diff --git a/benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py b/benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py index 51d47dcfd78..d33a98ddbbc 100644 --- a/benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py +++ b/benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py @@ -81,9 +81,9 @@ def bench(shape, layer_id, p, fusion_types=[""]): torch._dynamo.reset() torch._inductor.metrics.reset() triton_mm_ms, _, _ = benchmarker.benchmark_gpu(fn) - assert ( - torch._inductor.metrics.generated_kernel_count == 1 - ), "codegen #kernel != 1" + assert torch._inductor.metrics.generated_kernel_count == 1, ( + "codegen #kernel != 1" + ) row.extend([tflops(torch_mm_ms), tflops(triton_mm_ms)]) p.add_row(row) diff --git a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py index 18a1878a18e..36a212625f1 100644 --- a/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py +++ b/benchmarks/dynamo/microbenchmarks/operator_inp_utils.py @@ -265,9 +265,9 @@ class OperatorInputsLoader: def get_inputs_for_operator( self, operator, dtype=None, device="cuda" ) -> Generator[tuple[Iterable[Any], dict[str, Any]], None, None]: - assert ( - str(operator) in self.operator_db - ), f"Could not find {operator}, must provide overload" + assert str(operator) in self.operator_db, ( + f"Could not find {operator}, must provide overload" + ) if "embedding" in str(operator): log.warning("Embedding inputs NYI, input data cannot be randomized") @@ -302,9 +302,9 @@ class OperatorInputsLoader: yield op def get_call_frequency(self, op): - assert ( - str(op) in self.operator_db - ), f"Could not find {op}, must provide overload" + assert str(op) in self.operator_db, ( + f"Could not find {op}, must provide overload" + ) count = 0 for counter in self.operator_db[str(op)].values(): diff --git a/benchmarks/functional_autograd_benchmark/torchaudio_models.py b/benchmarks/functional_autograd_benchmark/torchaudio_models.py index 184762a198d..40a3b853d6e 100644 --- a/benchmarks/functional_autograd_benchmark/torchaudio_models.py +++ b/benchmarks/functional_autograd_benchmark/torchaudio_models.py @@ -538,21 +538,21 @@ class MultiheadAttentionContainer(torch.nn.Module): query.size(-1), ) q, k, v = self.in_proj_container(query, key, value) - assert ( - q.size(-1) % self.nhead == 0 - ), "query's embed_dim must be divisible by the number of heads" + assert q.size(-1) % self.nhead == 0, ( + "query's embed_dim must be divisible by the number of heads" + ) head_dim = q.size(-1) // self.nhead q = q.reshape(tgt_len, bsz * self.nhead, head_dim) - assert ( - k.size(-1) % self.nhead == 0 - ), "key's embed_dim must be divisible by the number of heads" + assert k.size(-1) % self.nhead == 0, ( + "key's embed_dim must be divisible by the number of heads" + ) head_dim = k.size(-1) // self.nhead k = k.reshape(src_len, bsz * self.nhead, head_dim) - assert ( - v.size(-1) % self.nhead == 0 - ), "value's embed_dim must be divisible by the number of heads" + assert v.size(-1) % self.nhead == 0, ( + "value's embed_dim must be divisible by the number of heads" + ) head_dim = v.size(-1) // self.nhead v = v.reshape(src_len, bsz * self.nhead, head_dim) @@ -629,9 +629,9 @@ class ScaledDotProduct(torch.nn.Module): attn_mask = torch.nn.functional.pad(_attn_mask, [0, 1]) tgt_len, head_dim = query.size(-3), query.size(-1) - assert ( - query.size(-1) == key.size(-1) == value.size(-1) - ), "The feature dim of query, key, value must be equal." + assert query.size(-1) == key.size(-1) == value.size(-1), ( + "The feature dim of query, key, value must be equal." + ) assert key.size() == value.size(), "Shape of key, value must match" src_len = key.size(-3) batch_heads = max(query.size(-2), key.size(-2)) diff --git a/benchmarks/functional_autograd_benchmark/torchvision_models.py b/benchmarks/functional_autograd_benchmark/torchvision_models.py index 1d45701ab23..25dd91c02d6 100644 --- a/benchmarks/functional_autograd_benchmark/torchvision_models.py +++ b/benchmarks/functional_autograd_benchmark/torchvision_models.py @@ -884,9 +884,9 @@ class HungarianMatcher(nn.Module): self.cost_class = cost_class self.cost_bbox = cost_bbox self.cost_giou = cost_giou - assert ( - cost_class != 0 or cost_bbox != 0 or cost_giou != 0 - ), "all costs cant be 0" + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, ( + "all costs cant be 0" + ) @torch.no_grad() def forward(self, outputs, targets): diff --git a/benchmarks/gpt_fast/model.py b/benchmarks/gpt_fast/model.py index bd438aea7a0..e1a675dcb76 100644 --- a/benchmarks/gpt_fast/model.py +++ b/benchmarks/gpt_fast/model.py @@ -51,9 +51,9 @@ class ModelArgs: # take longer name (as it have more symbols matched) if len(config) > 1: config.sort(key=len, reverse=True) - assert len(config[0]) != len( - config[1] - ), name # make sure only one 'best' match + assert len(config[0]) != len(config[1]), ( + name + ) # make sure only one 'best' match return cls(**transformer_configs[config[0]]) diff --git a/benchmarks/instruction_counts/core/expand.py b/benchmarks/instruction_counts/core/expand.py index e985c6bb982..d83b46af371 100644 --- a/benchmarks/instruction_counts/core/expand.py +++ b/benchmarks/instruction_counts/core/expand.py @@ -80,9 +80,9 @@ def _generate_torchscript_file(model_src: str, name: str) -> Optional[str]: # And again, the type checker has no way of knowing that this line is valid. jit_model = module.jit_model # type: ignore[attr-defined] - assert isinstance( - jit_model, (torch.jit.ScriptFunction, torch.jit.ScriptModule) - ), f"Expected ScriptFunction or ScriptModule, got: {type(jit_model)}" + assert isinstance(jit_model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)), ( + f"Expected ScriptFunction or ScriptModule, got: {type(jit_model)}" + ) jit_model.save(artifact_path) # type: ignore[call-arg] # Cleanup now that we have the actual serialized model. diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index fb4ded6c9fb..8d91f4bf475 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py +++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -276,9 +276,9 @@ class BenchmarkRunner: if c in open_to_close.keys(): curr_brackets.append(c) elif c in open_to_close.values(): - assert ( - curr_brackets and open_to_close[curr_brackets[-1]] == c - ), "ERROR: not able to parse the string!" + assert curr_brackets and open_to_close[curr_brackets[-1]] == c, ( + "ERROR: not able to parse the string!" + ) curr_brackets.pop() elif c == "," and (not curr_brackets): break_idxs.append(i) diff --git a/benchmarks/sparse/triton_ops.py b/benchmarks/sparse/triton_ops.py index 6f5fc44e8ef..48a88d592ea 100644 --- a/benchmarks/sparse/triton_ops.py +++ b/benchmarks/sparse/triton_ops.py @@ -3,9 +3,9 @@ from torch._inductor.runtime.benchmarking import benchmarker def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device): - assert ( - sparsity <= 1.0 and sparsity >= 0.0 - ), "sparsity should be a value between 0 and 1" + assert sparsity <= 1.0 and sparsity >= 0.0, ( + "sparsity should be a value between 0 and 1" + ) assert M % blocksize[0] == 0 assert N % blocksize[1] == 0 shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :] diff --git a/benchmarks/transformer/attention_bias_benchmarks.py b/benchmarks/transformer/attention_bias_benchmarks.py index c5bb2523e83..2154e11237e 100644 --- a/benchmarks/transformer/attention_bias_benchmarks.py +++ b/benchmarks/transformer/attention_bias_benchmarks.py @@ -84,9 +84,9 @@ class CompositeMHA(torch.nn.Module): self.head_dim = embed_dim // num_heads self.embed_dim = embed_dim - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" + assert self.head_dim * num_heads == self.embed_dim, ( + "embed_dim must be divisible by num_heads" + ) self.q_proj_weight = Parameter( torch.empty((embed_dim, embed_dim), **factory_kwargs) diff --git a/benchmarks/transformer/score_mod.py b/benchmarks/transformer/score_mod.py index 16c77fb7989..a2de7538898 100644 --- a/benchmarks/transformer/score_mod.py +++ b/benchmarks/transformer/score_mod.py @@ -49,9 +49,9 @@ class ExperimentConfig: backends: list[str] def __post_init__(self): - assert ( - len(self.shape) == 6 - ), "Shape must be of length 6" # [B, Hq, M, Hkv, N, D] + assert len(self.shape) == 6, ( + "Shape must be of length 6" + ) # [B, Hq, M, Hkv, N, D] def asdict(self): # Convert the dataclass instance to a dictionary diff --git a/functorch/dim/reference.py b/functorch/dim/reference.py index 01992cc5c12..5c6178c0981 100644 --- a/functorch/dim/reference.py +++ b/functorch/dim/reference.py @@ -625,9 +625,9 @@ def split(self, split_size_or_sections, dim=0): unbound.append(i) if unbound: - assert ( - total_bound_size <= size - ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})" + assert total_bound_size <= size, ( + f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})" + ) remaining_size = size - total_bound_size chunk_size = -(-remaining_size // len(unbound)) for u in unbound: @@ -636,9 +636,9 @@ def split(self, split_size_or_sections, dim=0): sizes[u] = sz remaining_size -= sz else: - assert ( - total_bound_size == size - ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})" + assert total_bound_size == size, ( + f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})" + ) return tuple( t.index(dim, d) for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim)) diff --git a/scripts/compile_tests/download_reports.py b/scripts/compile_tests/download_reports.py index fa9b43e02a3..03804b11f7e 100644 --- a/scripts/compile_tests/download_reports.py +++ b/scripts/compile_tests/download_reports.py @@ -62,9 +62,9 @@ def download_reports(commit_sha, configs=("dynamo39", "dynamo311", "eager311")): for config in configs: required_jobs.extend(list(CONFIGS[config])) for job in required_jobs: - assert ( - job in workflow_jobs - ), f"{job} not found, is the commit_sha correct? has the job finished running? The GitHub API may take a couple minutes to update." + assert job in workflow_jobs, ( + f"{job} not found, is the commit_sha correct? has the job finished running? The GitHub API may take a couple minutes to update." + ) # This page lists all artifacts. listings = requests.get( diff --git a/scripts/export/update_schema.py b/scripts/export/update_schema.py index 904cf2b7d8c..fa2a54f364f 100644 --- a/scripts/export/update_schema.py +++ b/scripts/export/update_schema.py @@ -23,9 +23,9 @@ if __name__ == "__main__": ) args = parser.parse_args() - assert os.path.exists( - args.prefix - ), f"Assuming path {args.prefix} is the root of pytorch directory, but it doesn't exist." + assert os.path.exists(args.prefix), ( + f"Assuming path {args.prefix} is the root of pytorch directory, but it doesn't exist." + ) commit = schema_check.update_schema() @@ -40,7 +40,9 @@ if __name__ == "__main__": f"Treespec version downgraded from {commit.base['TREESPEC_VERSION']} to {commit.result['TREESPEC_VERSION']}." ) else: - assert args.force_unsafe, "Existing schema yaml file not found, please use --force-unsafe to try again." + assert args.force_unsafe, ( + "Existing schema yaml file not found, please use --force-unsafe to try again." + ) next_version, reason = schema_check.check(commit, args.force_unsafe) diff --git a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py index 35e725ca71c..ef8fdff4bcb 100644 --- a/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py +++ b/test/onnx/dynamo/test_dynamo_with_onnxruntime_backend.py @@ -182,9 +182,9 @@ class TestDynamoWithONNXRuntime(onnx_test_common._TestONNXRuntime): baseline_param.grad, param.grad, atol=atol, rtol=rtol ) else: - assert ( - test_backward is False - ), "Calculating backward with multiple outputs is not supported yet." + assert test_backward is False, ( + "Calculating backward with multiple outputs is not supported yet." + ) for baseline_elem, result_elem in zip(baseline_result, result): torch.testing.assert_close( baseline_elem, result_elem, atol=atol, rtol=rtol diff --git a/test/onnx/exporter/test_hf_models_e2e.py b/test/onnx/exporter/test_hf_models_e2e.py index 08e33681d40..66690f17c88 100644 --- a/test/onnx/exporter/test_hf_models_e2e.py +++ b/test/onnx/exporter/test_hf_models_e2e.py @@ -30,15 +30,13 @@ class DynamoExporterTest(common_utils.TestCase): onnx_testing.assert_onnx_program(onnx_program) -def _prepare_llm_model_gptj_to_test() -> ( - tuple[ - torch.nn.Module, - dict[str, Any], - dict[str, dict[int, str]], - list[str], - list[str], - ] -): +def _prepare_llm_model_gptj_to_test() -> tuple[ + torch.nn.Module, + dict[str, Any], + dict[str, dict[int, str]], + list[str], + list[str], +]: model = transformers.GPTJForCausalLM.from_pretrained( "hf-internal-testing/tiny-random-gptj" ) diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index 9bdce4de117..d446130cfc4 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -93,9 +93,9 @@ def assert_dynamic_shapes(onnx_program: torch.onnx.ONNXProgram, dynamic_shapes: for dim in inp.type.tensor_type.shape.dim if dim.dim_value == 0 and dim.dim_param != "" ] - assert dynamic_shapes == ( - len(dynamic_inputs) > 0 - ), "Dynamic shape check failed for graph inputs" + assert dynamic_shapes == (len(dynamic_inputs) > 0), ( + "Dynamic shape check failed for graph inputs" + ) def parameterize_class_name(cls: type, idx: int, input_dicts: Mapping[Any, Any]): @@ -249,9 +249,9 @@ class _TestONNXRuntime(pytorch_test_common.ExportTestCase): ref_input_args = input_args ref_input_kwargs = input_kwargs - assert isinstance(ref_model, torch.nn.Module) or callable( - ref_model - ), "Model must be a torch.nn.Module or callable" + assert isinstance(ref_model, torch.nn.Module) or callable(ref_model), ( + "Model must be a torch.nn.Module or callable" + ) if ( self.model_type == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM @@ -640,9 +640,9 @@ def add_decorate_info( # Skip does not apply to this opset continue opinfo = ops_mapping.get((decorate_meta.op_name, decorate_meta.variant_name)) - assert ( - opinfo is not None - ), f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?" + assert opinfo is not None, ( + f"Couldn't find OpInfo for {decorate_meta}. Did you need to specify variant_name?" + ) assert decorate_meta.model_type is None, ( f"Tested op: {decorate_meta.op_name} in wrong position! " "If model_type needs to be specified, it should be " diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py index 408168a9c71..1cdae000eda 100644 --- a/test/onnx/pytorch_test_common.py +++ b/test/onnx/pytorch_test_common.py @@ -294,13 +294,13 @@ def xfail(error_message: str, reason: Optional[str] = None): except Exception as e: if isinstance(e, torch.onnx.OnnxExporterError): # diagnostic message is in the cause of the exception - assert ( - error_message in str(e.__cause__) - ), f"Expected error message: {error_message} NOT in {str(e.__cause__)}" + assert error_message in str(e.__cause__), ( + f"Expected error message: {error_message} NOT in {str(e.__cause__)}" + ) else: - assert error_message in str( - e - ), f"Expected error message: {error_message} NOT in {str(e)}" + assert error_message in str(e), ( + f"Expected error message: {error_message} NOT in {str(e)}" + ) pytest.xfail(reason if reason else f"Expected failure: {error_message}") else: pytest.fail("Unexpected success!") diff --git a/test/onnx/test_fx_passes.py b/test/onnx/test_fx_passes.py index 6fc3c643175..97d255abdcb 100644 --- a/test/onnx/test_fx_passes.py +++ b/test/onnx/test_fx_passes.py @@ -31,9 +31,9 @@ class TestFxPasses(common_utils.TestCase): name_to_node = {node.name: node for node in gm.graph.nodes} pass_utils.set_node_name(nodes[0], base_name, name_to_node) assert nodes[0].name == base_name, f"Expected {base_name}, got {nodes[0].name}" - assert len({node.name for node in nodes}) == len( - nodes - ), f"Expected all names to be unique, got {nodes}" + assert len({node.name for node in nodes}) == len(nodes), ( + f"Expected all names to be unique, got {nodes}" + ) def test_set_node_name_succeeds_when_no_name_collisions(self): def func(x, y, z): @@ -51,9 +51,9 @@ class TestFxPasses(common_utils.TestCase): name_to_node = {node.name: node for node in nodes} pass_utils.set_node_name(nodes[1], new_name, name_to_node) assert nodes[1].name == new_name, f"Expected {new_name}, got {nodes[0].name}" - assert len({node.name for node in nodes}) == len( - nodes - ), f"Expected all names to be unique, got {nodes}" + assert len({node.name for node in nodes}) == len(nodes), ( + f"Expected all names to be unique, got {nodes}" + ) if __name__ == "__main__": diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index abb647af6b9..24d3ffd50e2 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -12630,12 +12630,12 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime): actual_std = np.std(ort_out) actual_mean = np.mean(ort_out) - assert ( - abs(abs(actual_mean) - expected_mean) <= expected_mean * 0.1 - ), "the gap of mean between ort outputs and expected one is unacceptable." - assert ( - abs(abs(actual_std) - expected_std) <= expected_std * 0.1 - ), "the gap of variance between ort outputs and expected one is unacceptable." + assert abs(abs(actual_mean) - expected_mean) <= expected_mean * 0.1, ( + "the gap of mean between ort outputs and expected one is unacceptable." + ) + assert abs(abs(actual_std) - expected_std) <= expected_std * 0.1, ( + "the gap of variance between ort outputs and expected one is unacceptable." + ) @skipScriptTest() @skipIfUnsupportedMinOpsetVersion(11) @@ -12661,12 +12661,12 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime): actual_std = np.std(ort_out) actual_mean = np.mean(ort_out) - assert ( - abs(abs(actual_mean) - expected_mean) <= expected_mean * 0.1 - ), "the gap of mean between ort outputs and expected one is unacceptable." - assert ( - abs(abs(actual_std) - expected_std) <= expected_std * 0.1 - ), "the gap of variance between ort outputs and expected one is unacceptable." + assert abs(abs(actual_mean) - expected_mean) <= expected_mean * 0.1, ( + "the gap of mean between ort outputs and expected one is unacceptable." + ) + assert abs(abs(actual_std) - expected_std) <= expected_std * 0.1, ( + "the gap of variance between ort outputs and expected one is unacceptable." + ) @skipScriptTest() @skipIfUnsupportedMinOpsetVersion(11) @@ -12705,15 +12705,15 @@ class TestONNXRuntime(onnx_test_common._TestONNXRuntime): actual_max = np.max(ort_out) actual_mean = np.mean(ort_out) - assert ( - actual_min >= expected_min - ), "the minimum value of ort outputs is out of scope." - assert ( - actual_max <= expected_max - ), "the maximum value of ort outputs is out of scope." - assert ( - abs(actual_mean - expected_mean) <= expected_mean * 0.05 - ), "the mean value of ort outputs is out of scope." + assert actual_min >= expected_min, ( + "the minimum value of ort outputs is out of scope." + ) + assert actual_max <= expected_max, ( + "the maximum value of ort outputs is out of scope." + ) + assert abs(actual_mean - expected_mean) <= expected_mean * 0.05, ( + "the mean value of ort outputs is out of scope." + ) @skipIfUnsupportedMinOpsetVersion(13) def test_sequence_to_int(self): diff --git a/test/run_test.py b/test/run_test.py index 6b7469c4f04..feca1a2ed02 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1441,9 +1441,9 @@ def get_pytest_args(options, is_cpp_test=False, is_distributed_test=False): def run_ci_sanity_check(test: ShardedTest, test_directory, options): - assert ( - test.name == "test_ci_sanity_check_fail" - ), f"This handler only works for test_ci_sanity_check_fail, got {test.name}" + assert test.name == "test_ci_sanity_check_fail", ( + f"This handler only works for test_ci_sanity_check_fail, got {test.name}" + ) ret_code = run_test(test, test_directory, options, print_log=False) # This test should fail if ret_code != 1: @@ -1951,9 +1951,9 @@ def get_sharding_opts(options) -> tuple[int, int]: assert len(options.shard) == 2, "Unexpected shard format" assert min(options.shard) > 0, "Shards must be positive numbers" which_shard, num_shards = options.shard - assert ( - which_shard <= num_shards - ), "Selected shard must be less than or equal to total number of shards" + assert which_shard <= num_shards, ( + "Selected shard must be less than or equal to total number of shards" + ) return (which_shard, num_shards) @@ -1996,9 +1996,9 @@ def run_test_module( print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]") handler = CUSTOM_HANDLERS.get(test_name, run_test) return_code = handler(test, test_directory, options) - assert isinstance(return_code, int) and not isinstance( - return_code, bool - ), f"While running {str(test)} got non integer return code {return_code}" + assert isinstance(return_code, int) and not isinstance(return_code, bool), ( + f"While running {str(test)} got non integer return code {return_code}" + ) if return_code == 0: return None diff --git a/test/test_dataloader.py b/test/test_dataloader.py index b609ad533cf..3edf0163288 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -665,12 +665,12 @@ class ErrorTrackingProcess(mp.Process): raise def print_traces_of_all_threads(self): - assert ( - self.is_alive() - ), "can only use print_traces_of_all_threads if the process is alive" - assert ( - not self.disable_stderr - ), "do not disable stderr if you use print_traces_of_all_threads" + assert self.is_alive(), ( + "can only use print_traces_of_all_threads if the process is alive" + ) + assert not self.disable_stderr, ( + "do not disable stderr if you use print_traces_of_all_threads" + ) # On platforms without `SIGUSR1`, `set_faulthander_if_available` sets # `faulthandler.enable()`, and `print_traces_of_all_threads` may kill # the process. So let's poll the exception first @@ -1030,19 +1030,19 @@ class TestWorkerInfoDataset(SynchronizedDataset): # See _test_get_worker_info below for usage. def _test_worker_info_init_fn(worker_id): worker_info = torch.utils.data.get_worker_info() - assert ( - worker_id == worker_info.id - ), "worker_init_fn and worker_info should have consistent id" - assert ( - worker_id < worker_info.num_workers - ), "worker_init_fn and worker_info should have valid id" - assert ( - worker_info.seed == torch.initial_seed() - ), "worker_init_fn and worker_info should have consistent seed" + assert worker_id == worker_info.id, ( + "worker_init_fn and worker_info should have consistent id" + ) + assert worker_id < worker_info.num_workers, ( + "worker_init_fn and worker_info should have valid id" + ) + assert worker_info.seed == torch.initial_seed(), ( + "worker_init_fn and worker_info should have consistent seed" + ) dataset = worker_info.dataset - assert isinstance( - dataset, TestWorkerInfoDataset - ), "worker_info should have correct dataset copy" + assert isinstance(dataset, TestWorkerInfoDataset), ( + "worker_info should have correct dataset copy" + ) assert not hasattr(dataset, "value"), "worker_info should have correct dataset copy" # test that WorkerInfo attributes are read-only try: diff --git a/test/test_foreach.py b/test/test_foreach.py index 3c1ffcaebb7..9c0fc438d13 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -87,9 +87,9 @@ class ForeachFuncWrapper: actual = self.func(*inputs, **kwargs) keys = tuple([e.key for e in p.key_averages()]) mta_called = any("multi_tensor_apply_kernel" in k for k in keys) - assert ( - mta_called == (expect_fastpath and (not zero_size)) - ), f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}" + assert mta_called == (expect_fastpath and (not zero_size)), ( + f"{mta_called=}, {expect_fastpath=}, {zero_size=}, {self.func.__name__=}, {keys=}" + ) else: actual = self.func(*inputs, **kwargs) if self.is_inplace: diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 4c50a00d7de..d08141981c4 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -169,9 +169,9 @@ def random_nt( assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim" assert min_dim >= 0, "random_nt: min_dim must be non-negative" if require_non_empty: - assert not ( - min_dim == 0 and max_dim == 1 - ), "random_nt: zero cannot be the only possible value if require_non_empty is True" + assert not (min_dim == 0 and max_dim == 1), ( + "random_nt: zero cannot be the only possible value if require_non_empty is True" + ) if require_non_empty: # Select a random idx that will be required to be non-empty diff --git a/test/test_optim.py b/test/test_optim.py index da6ddd6f724..2dcf3faecd6 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -52,17 +52,17 @@ FP16_REDUCED_PRECISION = {"atol": 1e-5, "rtol": 1e-4} def rosenbrock(tensor): - assert tensor.size() == torch.Size( - [2] - ), f"Requires tensor with 2 scalars but got {tensor.size()}" + assert tensor.size() == torch.Size([2]), ( + f"Requires tensor with 2 scalars but got {tensor.size()}" + ) x, y = tensor return (1 - x) ** 2 + 100 * (y - x**2) ** 2 def drosenbrock(tensor): - assert tensor.size() == torch.Size( - [2] - ), f"Requires tensor with 2 scalars but got {tensor.size()}" + assert tensor.size() == torch.Size([2]), ( + f"Requires tensor with 2 scalars but got {tensor.size()}" + ) x, y = tensor return torch.stack((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2))) diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 178fc9b1111..f1e0140a415 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -616,9 +616,9 @@ def load_deprecated_signatures( } schema_args_by_name = {a.name: a for a in schema.arguments.flat_all} for name in call_args: - assert ( - name in schema_args_by_name or name in known_constants - ), f"deprecation definiton: Unrecognized value {name}" + assert name in schema_args_by_name or name in known_constants, ( + f"deprecation definiton: Unrecognized value {name}" + ) # Map deprecated signature arguments to their aten signature and test # if the types and alias annotation match. @@ -683,7 +683,9 @@ def load_deprecated_signatures( function=pair.function, ) ) - assert any_schema_found, f"No native function with name {aten_name} matched signature:\n {str(schema)}" + assert any_schema_found, ( + f"No native function with name {aten_name} matched signature:\n {str(schema)}" + ) return results diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 35030359203..ed5a6e6cf39 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -1075,7 +1075,9 @@ def emit_body( assert ( base_name_and_overload_name in _foreach_ops_without_differentiability_info - ), f"{'.'.join(base_name_and_overload_name)} should have a differentiability info" + ), ( + f"{'.'.join(base_name_and_overload_name)} should have a differentiability info" + ) else: assert ( len(f.func.arguments.flat_non_out) @@ -1634,9 +1636,9 @@ def emit_body( noref_cpp_type = cpp.return_type(ret, symint=True).remove_const_ref() if noref_cpp_type == BaseCType(tensorT): if aliased_arg_name is not None: - assert ( - i == 0 - ), "Expect non-CompositeImplicitAutograd view function {base} to return single output" + assert i == 0, ( + "Expect non-CompositeImplicitAutograd view function {base} to return single output" + ) stmts_after_call += [ ENFORCE_SAME_TENSOR_STORAGE.substitute( tensor_name=aliased_arg_name, out_tensor_name=ret_name @@ -1873,9 +1875,9 @@ def emit_body( for derivative in fw_derivatives: res = derivative.var_names if f.func.name.name.inplace: - assert ( - len(res) == 1 - ), "Expected number of outputs to be 1 if function is inplace" + assert len(res) == 1, ( + "Expected number of outputs to be 1 if function is inplace" + ) # TODO update this when inplace namings are unified res = ("self",) @@ -1973,9 +1975,9 @@ def emit_body( isinstance(derivative.var_types[0], ListType) and derivative.var_types[0].is_tensor_like() ): - assert ( - len(derivative.var_types) == 1 - ), "Expected number of outputs to be 1 if function returns ListType" + assert len(derivative.var_types) == 1, ( + "Expected number of outputs to be 1 if function returns ListType" + ) if not is_foreach: opt_res_grad_type = OptionalCType( VectorCType(BaseCType(tensorT)) diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index fbab7bc5aff..9d600a81575 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -313,9 +313,9 @@ def postprocess_forward_derivatives( formula = defn.formula required_inputs_tangent = find_required_inputs(formula, "_t") if formula == "auto_element_wise": - assert ( - f.func.kind() != SchemaKind.inplace - ), f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant" + assert f.func.kind() != SchemaKind.inplace, ( + f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant" + ) if ( (not len(args_with_derivatives) == 1) or len(forward_derivatives) > 1 diff --git a/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py b/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py index 63b1e4baf51..b9abd3fafed 100644 --- a/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py +++ b/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py @@ -41,9 +41,9 @@ def parse_segments(raw_segments: list[list[int]]) -> list[LlvmCoverageSegment]: """ ret: list[LlvmCoverageSegment] = [] for raw_segment in raw_segments: - assert ( - len(raw_segment) == 5 or len(raw_segment) == 6 - ), "list is not compatible with llvmcom export:" + assert len(raw_segment) == 5 or len(raw_segment) == 6, ( + "list is not compatible with llvmcom export:" + ) " Expected to have 5 or 6 elements" if len(raw_segment) == 5: ret.append( diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index d6b60bc9620..f4f116bab7a 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -116,12 +116,12 @@ def build_groups_memberships( _memberships[pg_guid] = set(ranks) else: # validation across ranks - assert ( - _groups[pg_guid].desc == desc - ), f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}" - assert ( - _memberships[pg_guid] == set(ranks) - ), f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}" + assert _groups[pg_guid].desc == desc, ( + f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}" + ) + assert _memberships[pg_guid] == set(ranks), ( + f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}" + ) return groups, _groups, memberships, _memberships, _pg_guids diff --git a/tools/flight_recorder/components/config_manager.py b/tools/flight_recorder/components/config_manager.py index 48b3687a8aa..ea9b0cf3918 100644 --- a/tools/flight_recorder/components/config_manager.py +++ b/tools/flight_recorder/components/config_manager.py @@ -73,13 +73,13 @@ class JobConfig: ) -> argparse.Namespace: args = self.parser.parse_args(args) if args.selected_ranks is not None: - assert ( - args.just_print_entries - ), "Not support selecting ranks without printing entries" + assert args.just_print_entries, ( + "Not support selecting ranks without printing entries" + ) if args.pg_filters is not None: - assert ( - args.just_print_entries - ), "Not support selecting pg filters without printing entries" + assert args.just_print_entries, ( + "Not support selecting pg filters without printing entries" + ) if args.verbose: logger.set_log_level(logging.DEBUG) return args diff --git a/tools/flight_recorder/components/loader.py b/tools/flight_recorder/components/loader.py index 54081fbaffa..d836779b585 100644 --- a/tools/flight_recorder/components/loader.py +++ b/tools/flight_recorder/components/loader.py @@ -85,8 +85,8 @@ def read_dir(args: argparse.Namespace) -> tuple[dict[str, dict[str, Any]], str]: if not version: version = str(details[f]["version"]) tb = time.time() - assert ( - len(details) > 0 - ), f"no files loaded from {args.trace_dir} with prefix {prefix}" + assert len(details) > 0, ( + f"no files loaded from {args.trace_dir} with prefix {prefix}" + ) logger.debug("loaded %s files in %ss", filecount, tb - t0) return details, version diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index 2cd3c6292ca..4fd412cda95 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -378,9 +378,9 @@ class Op: meta = parts[1] if len(parts) == 2 else None self.state = event["state"] self.pg_name, self.pg_desc = event["process_group"] - assert type in COLLECTIVES | P2P | { - "coalesced" - }, f"{type} is not a supported operation" + assert type in COLLECTIVES | P2P | {"coalesced"}, ( + f"{type} is not a supported operation" + ) self.type = type if type == "send": assert isinstance(meta, str) diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index 3933d9a8baa..8f8bf9d7b5a 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -263,16 +263,16 @@ def check_no_missing_dump_files( for membership in memberships: all_ranks.add(int(membership.global_rank)) dumps_ranks = {int(key) for key in entries.keys()} - assert ( - dumps_ranks == all_ranks - ), f"Missing dump files from ranks {all_ranks - dumps_ranks}" + assert dumps_ranks == all_ranks, ( + f"Missing dump files from ranks {all_ranks - dumps_ranks}" + ) def check_version(version_by_ranks: dict[str, str], version: str) -> None: for rank, v in version_by_ranks.items(): - assert ( - v == version - ), f"Rank {rank} has different version {v} from the given version {version}" + assert v == version, ( + f"Rank {rank} has different version {v} from the given version {version}" + ) def get_version_detail(version: str) -> tuple[int, int]: diff --git a/tools/onnx/gen_diagnostics.py b/tools/onnx/gen_diagnostics.py index 8f9d82f30f5..df01754bb60 100644 --- a/tools/onnx/gen_diagnostics.py +++ b/tools/onnx/gen_diagnostics.py @@ -102,13 +102,13 @@ def _format_rule_for_python_class(rule: _RuleType) -> str: if field_name is not None ] for field_name in field_names: - assert isinstance( - field_name, str - ), f"Unexpected field type {type(field_name)} from {field_name}. " + assert isinstance(field_name, str), ( + f"Unexpected field type {type(field_name)} from {field_name}. " + ) "Field name must be string.\nFull message template: {message_template}" - assert ( - not field_name.isnumeric() - ), f"Unexpected numeric field name {field_name}. " + assert not field_name.isnumeric(), ( + f"Unexpected numeric field name {field_name}. " + ) "Only keyword name formatting is supported.\nFull message template: {message_template}" message_arguments = ", ".join(field_names) message_arguments_assigned = ", ".join( diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py index 6e0a64888f0..64a12c0d228 100644 --- a/tools/setup_helpers/generate_code.py +++ b/tools/setup_helpers/generate_code.py @@ -212,12 +212,12 @@ def main() -> None: lazy_install_dir = os.path.join(install_dir, "lazy/generated") os.makedirs(lazy_install_dir, exist_ok=True) - assert os.path.isfile( - ts_backend_yaml - ), f"Unable to access ts_backend_yaml: {ts_backend_yaml}" - assert os.path.isfile( - ts_native_functions - ), f"Unable to access {ts_native_functions}" + assert os.path.isfile(ts_backend_yaml), ( + f"Unable to access ts_backend_yaml: {ts_backend_yaml}" + ) + assert os.path.isfile(ts_native_functions), ( + f"Unable to access {ts_native_functions}" + ) from torchgen.dest.lazy_ir import GenTSLazyIR from torchgen.gen_lazy_tensor import run_gen_lazy_tensor diff --git a/tools/test/gen_operators_yaml_test.py b/tools/test/gen_operators_yaml_test.py index ef129974feb..815c8bf9fb5 100644 --- a/tools/test/gen_operators_yaml_test.py +++ b/tools/test/gen_operators_yaml_test.py @@ -94,9 +94,9 @@ class GenOperatorsYAMLTest(unittest.TestCase): ] filtered_configs = list(filter(filter_func, config)) - assert ( - len(filtered_configs) == 2 - ), f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}" + assert len(filtered_configs) == 2, ( + f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}" + ) def test_verification_success(self) -> None: filter_func = make_filter_from_options( diff --git a/tools/testing/target_determination/heuristics/interface.py b/tools/testing/target_determination/heuristics/interface.py index e665be3ce87..48fbfa342a9 100644 --- a/tools/testing/target_determination/heuristics/interface.py +++ b/tools/testing/target_determination/heuristics/interface.py @@ -48,20 +48,20 @@ class TestPrioritizations: if test.test_file not in files: files[test.test_file] = copy(test) else: - assert ( - files[test.test_file] & test - ).is_empty(), ( + assert (files[test.test_file] & test).is_empty(), ( f"Test run `{test}` overlaps with `{files[test.test_file]}`" ) files[test.test_file] |= test for test in files.values(): - assert test.is_full_file(), f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that" # noqa: B950 + assert test.is_full_file(), ( + f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that" + ) # noqa: B950 # Ensure that the set of tests in the TestPrioritizations is identical to the set of tests passed in - assert ( - self._original_tests == set(files.keys()) - ), "The set of tests in the TestPrioritizations must be identical to the set of tests passed in" + assert self._original_tests == set(files.keys()), ( + "The set of tests in the TestPrioritizations must be identical to the set of tests passed in" + ) def _traverse_scores(self) -> Iterator[tuple[float, TestRun]]: # Sort by score, then alphabetically by test name @@ -228,9 +228,9 @@ class AggregatedHeuristics: def validate(self) -> None: for heuristic, heuristic_results in self._heuristic_results.items(): heuristic_results.validate() - assert ( - heuristic_results._original_tests == self._all_tests - ), f"Tests in {heuristic.name} are not the same as the tests in the AggregatedHeuristics" + assert heuristic_results._original_tests == self._all_tests, ( + f"Tests in {heuristic.name} are not the same as the tests in the AggregatedHeuristics" + ) def add_heuristic_results( self, heuristic: HeuristicInterface, heuristic_results: TestPrioritizations diff --git a/tools/testing/test_run.py b/tools/testing/test_run.py index c6637fce5fa..81bdfc4d708 100644 --- a/tools/testing/test_run.py +++ b/tools/testing/test_run.py @@ -43,9 +43,9 @@ class TestRun: exs = set(excluded or []) if "::" in name: - assert ( - not included and not excluded - ), "Can't specify included or excluded tests when specifying a test class in the file name" + assert not included and not excluded, ( + "Can't specify included or excluded tests when specifying a test class in the file name" + ) self.test_file, test_class = name.split("::") ins.add(test_class) else: @@ -148,9 +148,9 @@ class TestRun: return copy(self) # If not, ensure we have the same file - assert ( - self.test_file == other.test_file - ), f"Can't exclude {other} from {self} because they're not the same test file" + assert self.test_file == other.test_file, ( + f"Can't exclude {other} from {self} because they're not the same test file" + ) # 4 possible cases: diff --git a/tools/testing/test_selections.py b/tools/testing/test_selections.py index 635f512f087..fdecce02a0b 100644 --- a/tools/testing/test_selections.py +++ b/tools/testing/test_selections.py @@ -124,9 +124,9 @@ def get_duration( if included: return included_classes_duration - assert ( - excluded - ), f"TestRun {test} is not full file but doesn't have included or excluded classes" + assert excluded, ( + f"TestRun {test} is not full file but doesn't have included or excluded classes" + ) if file_duration is None: return None return file_duration - excluded_classes_duration @@ -140,9 +140,9 @@ def shard( ) -> None: # Modifies sharded_jobs in place if len(sharded_jobs) == 0: - assert ( - len(pytest_sharded_tests) == 0 - ), "No shards provided but there are tests to shard" + assert len(pytest_sharded_tests) == 0, ( + "No shards provided but there are tests to shard" + ) return round_robin_index = 0 diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 4a57a63e184..c3e08575194 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -255,9 +255,9 @@ def createResolutionCallbackFromEnv(lookup_base): def parseExpr(expr, module): try: value, len_parsed = parseNestedExpr(expr, module) - assert len_parsed == len( - expr - ), "whole expression was not parsed, falling back to c++ parser" + assert len_parsed == len(expr), ( + "whole expression was not parsed, falling back to c++ parser" + ) return value except Exception: """ diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index b40849528cf..03fe16f470c 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -795,9 +795,9 @@ class LOBPCG: # strict ordering of eigenpairs break count += 1 - assert ( - count >= prev_count - ), f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease" + assert count >= prev_count, ( + f"the number of converged eigenpairs (was {prev_count}, got {count}) cannot decrease" + ) self.ivars["converged_count"] = count self.tvars["rerr"] = rerr return count diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 51e098ff061..caf0fc47bfa 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -541,9 +541,9 @@ def meta_sparse_structured_linear( transposed_strides = (1, input.size(0)) if out_dtype is not None: - assert ( - input.dtype == torch.int8 and out_dtype == torch.int32 - ), "out_dtype is only supported for i8i8->i32 linear operator" + assert input.dtype == torch.int8 and out_dtype == torch.int32, ( + "out_dtype is only supported for i8i8->i32 linear operator" + ) output = input.new_empty( output_sizes, dtype=input.dtype if out_dtype is None else out_dtype, @@ -566,9 +566,9 @@ def meta_sparse_structured_mm( output_sizes = [mat1.size(0), mat2.size(1)] if out_dtype is not None: - assert ( - mat2.dtype == torch.int8 and out_dtype == torch.int32 - ), "out_dtype is only supported for i8i8->i32 linear operator" + assert mat2.dtype == torch.int8 and out_dtype == torch.int32, ( + "out_dtype is only supported for i8i8->i32 linear operator" + ) output = mat2.new_empty( output_sizes, dtype=mat2.dtype if out_dtype is None else out_dtype, @@ -588,22 +588,22 @@ def meta_sparse_structured_addmm( beta=1, out_dtype: Optional[torch.dtype] = None, ): - assert ( - len(input.shape) == 1 - ), "only input broadcasted to columns of mat1 * mat2 product is supported" + assert len(input.shape) == 1, ( + "only input broadcasted to columns of mat1 * mat2 product is supported" + ) assert len(mat1.shape) == 2 assert len(mat1_meta.shape) == 2 assert len(mat2.shape) == 2 - assert input.size(0) == mat1.size( - 0 - ), "only input broadcasted to columns of mat1 * mat2 product is supported" + assert input.size(0) == mat1.size(0), ( + "only input broadcasted to columns of mat1 * mat2 product is supported" + ) assert mat1.size(1) == mat2.size(0) / 2 output_sizes = [mat1.size(0), mat2.size(1)] if out_dtype is not None: - assert ( - mat2.dtype == torch.int8 and out_dtype == torch.int32 - ), "out_dtype is only supported for i8i8->i32 linear operator" + assert mat2.dtype == torch.int8 and out_dtype == torch.int32, ( + "out_dtype is only supported for i8i8->i32 linear operator" + ) output = mat2.new_empty( output_sizes, dtype=mat2.dtype if out_dtype is None else out_dtype, @@ -638,9 +638,9 @@ def meta__cslt_sparse_mm( compression_factor = 10 if is_8bit_input_type else 9 if is_8bit_input_type: - assert ( - not dense_B.is_contiguous() - ), "dense input must be transposed for 8bit dtypes" + assert not dense_B.is_contiguous(), ( + "dense input must be transposed for 8bit dtypes" + ) k = dense_B.size(0) n = dense_B.size(1) @@ -649,16 +649,14 @@ def meta__cslt_sparse_mm( assert m == bias.size(0) if out_dtype is not None: - assert ( - is_8bit_input_type - and out_dtype - in { - torch.float16, - torch.bfloat16, - torch.int32, - torch.float8_e4m3fn, - } - ), "out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!" + assert is_8bit_input_type and out_dtype in { + torch.float16, + torch.bfloat16, + torch.int32, + torch.float8_e4m3fn, + }, ( + "out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!" + ) output_shape = (n, m) if transpose_result else (m, n) return dense_B.new_empty(output_shape, dtype=out_dtype) @@ -861,12 +859,12 @@ def functional_assert_async_meta(val, assert_msg, dep_token): # From aten/src/ATen/native/LinearAlgebraUtils.h def squareCheckInputs(self: Tensor, f_name: str): - assert ( - self.dim() >= 2 - ), f"{f_name}: The input tensor must have at least 2 dimensions." - assert ( - self.size(-1) == self.size(-2) - ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices" + assert self.dim() >= 2, ( + f"{f_name}: The input tensor must have at least 2 dimensions." + ) + assert self.size(-1) == self.size(-2), ( + f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices" + ) # Validates input shapes and devices @@ -6500,9 +6498,9 @@ def topk_meta(self, k, dim=-1, largest=True, sorted=True): def meta__segment_reduce_backward( grad, output, data, reduce, lengths=None, offsets=None, axis=0, initial=None ): - assert ( - lengths is not None or offsets is not None - ), "segment_reduce(): Either lengths or offsets must be defined" + assert lengths is not None or offsets is not None, ( + "segment_reduce(): Either lengths or offsets must be defined" + ) data_contig = data.contiguous() grad_contig = grad.contiguous() return torch.empty_like( @@ -6579,7 +6577,9 @@ def linear_backward(input_, grad_output_, weight_, output_mask): def meta_pixel_shuffle(self, upscale_factor): assert ( len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0 - ), f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}" + ), ( + f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}" + ) def is_channels_last(ten): return torch._prims_common.suggest_memory_format(ten) == torch.channels_last @@ -6753,15 +6753,14 @@ def nan_to_num(self, nan=None, posinf=None, neginf=None): @register_meta(torch.ops.aten.transpose_) def transpose_(self, dim0, dim1): - assert ( - self.layout - not in { - torch.sparse_csr, - torch.sparse_csc, - torch.sparse_bsr, - torch.sparse_bsc, - } - ), f"torch.transpose_: in-place transposition is not supported for {self.layout} layout" + assert self.layout not in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + }, ( + f"torch.transpose_: in-place transposition is not supported for {self.layout} layout" + ) ndims = self.ndim @@ -6788,13 +6787,14 @@ def t_(self): if self.is_sparse: sparse_dim = self.sparse_dim() dense_dim = self.dense_dim() - assert ( - sparse_dim <= 2 and dense_dim == 0 - ), f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and {dense_dim} dense dimensions" # noqa: B950 + assert sparse_dim <= 2 and dense_dim == 0, ( + f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, " + f"but got {sparse_dim} sparse and {dense_dim} dense dimensions" + ) else: - assert ( - self.dim() <= 2 - ), f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D" + assert self.dim() <= 2, ( + f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D" + ) return transpose_(self, 0, 0 if ndims < 2 else 1) diff --git a/torch/_ops.py b/torch/_ops.py index 84b48ed1dec..c6f5be583e4 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -133,9 +133,9 @@ class OperatorBase: return fn assert isinstance(k, DispatchKey) - assert ( - k != DispatchKey.Python - ), "Please register a mode for the DispatchKey.Python key instead." + assert k != DispatchKey.Python, ( + "Please register a mode for the DispatchKey.Python key instead." + ) if k in self.py_kernels: raise RuntimeError( @@ -422,12 +422,12 @@ class HigherOrderOperator(OperatorBase, abc.ABC): DispatchKey.Python ): curr_mode = _get_current_dispatch_mode_pre_dispatch() - assert ( - curr_mode is not None - ), "Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode." - assert ( - type(curr_mode) in self.python_key_table - ), f"Current active mode {curr_mode} not registered" + assert curr_mode is not None, ( + "Illegal invocation of dispatch on DispatchKey.PreDispatch without a mode." + ) + assert type(curr_mode) in self.python_key_table, ( + f"Current active mode {curr_mode} not registered" + ) handler = self.python_key_table[type(curr_mode)] with _pop_mode_temporarily(functionality_key) as mode: return handler(mode, *args, **kwargs) @@ -828,9 +828,9 @@ class OpOverload(OperatorBase): # TODO: We also need to handle tensor subclasses here # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now. curr_mode = type(_get_current_dispatch_mode()) - assert ( - curr_mode is not None - ), "Illegal invocation of dispatch on DispatchKey.Python without a mode." + assert curr_mode is not None, ( + "Illegal invocation of dispatch on DispatchKey.Python without a mode." + ) if curr_mode not in self.python_key_table: if isinstance(self, TorchBindOpOverload): diff --git a/torch/_utils.py b/torch/_utils.py index f227042803f..6346d024c79 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -81,9 +81,9 @@ def _to(self, device, non_blocking=False): return untyped_storage device_module = getattr(torch, device.type, None) - assert ( - device_module is not None - ), f"{device.type.upper()} device module is not loaded" + assert device_module is not None, ( + f"{device.type.upper()} device module is not loaded" + ) with device_module.device(device): if self.is_sparse and hasattr(device_module, "sparse"): new_type = getattr(device_module.sparse, self.__class__.__name__) @@ -95,9 +95,9 @@ def _to(self, device, non_blocking=False): ) return new_type(indices, values, self.size()) else: - assert ( - not self.is_sparse - ), f"sparse storage is not supported for {device.type.upper()} tensors" + assert not self.is_sparse, ( + f"sparse storage is not supported for {device.type.upper()} tensors" + ) untyped_storage = torch.UntypedStorage(self.size(), device=device) untyped_storage.copy_(self, non_blocking) return untyped_storage diff --git a/torch/onnx/_internal/diagnostics/_diagnostic.py b/torch/onnx/_internal/diagnostics/_diagnostic.py index a8110ccf4b0..9ee564e9b13 100644 --- a/torch/onnx/_internal/diagnostics/_diagnostic.py +++ b/torch/onnx/_internal/diagnostics/_diagnostic.py @@ -165,18 +165,18 @@ _context = engine.background_context @contextlib.contextmanager -def create_export_diagnostic_context() -> ( - Generator[infra.DiagnosticContext, None, None] -): +def create_export_diagnostic_context() -> Generator[ + infra.DiagnosticContext, None, None +]: """Create a diagnostic context for export. This is a workaround for code robustness since diagnostic context is accessed by export internals via global variable. See `ExportDiagnosticEngine` for more details. """ global _context - assert ( - _context == engine.background_context - ), "Export context is already set. Nested export is not supported." + assert _context == engine.background_context, ( + "Export context is already set. Nested export is not supported." + ) _context = engine.create_diagnostic_context( "torch.onnx.export", torch.__version__, diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py index 9a01570c2f4..3aa29e53f91 100644 --- a/torch/onnx/_internal/exporter/_building.py +++ b/torch/onnx/_internal/exporter/_building.py @@ -98,9 +98,9 @@ def _construct_named_inputs_and_attrs( else: # Handle attributes attribute: ValidAttributeType | ir.Attr - assert isinstance( - param, _schemas.AttributeParameter - ), f"Expected AttributeParameter, got {type(param)}" + assert isinstance(param, _schemas.AttributeParameter), ( + f"Expected AttributeParameter, got {type(param)}" + ) if reversed_args_stack: # First exhaust the positional arguments attribute = reversed_args_stack.pop() # type: ignore[assignment] @@ -166,9 +166,9 @@ def _resolve_parameter_dtypes( type_binding = {} for name, arg in named_inputs.items(): param = signature.params_map[name] - assert isinstance( - param, _schemas.Parameter - ), f"Expected Parameter, got {type(param)}" + assert isinstance(param, _schemas.Parameter), ( + f"Expected Parameter, got {type(param)}" + ) if isinstance(arg, (int, float, bool, str, Sequence, torch.Tensor)): # Skip the Python constants because we do not know what dtype they should take yet continue @@ -318,9 +318,9 @@ def _process_python_constants( # - Otherwise, set named_inputs[param.name] = Constant(value) for name, arg in named_inputs.items(): param = signature.params_map[name] - assert isinstance( - param, _schemas.Parameter - ), f"Expected Parameter, got {type(param)}" + assert isinstance(param, _schemas.Parameter), ( + f"Expected Parameter, got {type(param)}" + ) if isinstance(arg, ir.Value): # TODO(justinchuby): Cast the ir.Value here if needed @@ -384,9 +384,9 @@ def _process_python_sequences( """ for name, arg in named_inputs.items(): param = signature.params_map[name] - assert isinstance( - param, _schemas.Parameter - ), f"Expected Parameter, got {type(param)}" + assert isinstance(param, _schemas.Parameter), ( + f"Expected Parameter, got {type(param)}" + ) if not isinstance(arg, (tuple, list)): continue @@ -447,9 +447,9 @@ def _process_python_sequences( ) else: # Turn the Python constant into 1D tensor for the constant - assert isinstance( - val, (bool, int, float) - ), f"Expected int or float, got {type(val)}" + assert isinstance(val, (bool, int, float)), ( + f"Expected int or float, got {type(val)}" + ) new_args.append( _get_or_create_constant(constant_farm, [val], dtype, opset) # type: ignore[arg-type] ) diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index 1e1c13b8fd5..42d60cc9c28 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -324,9 +324,9 @@ def _handle_getitem_node( assert len(node.all_input_nodes) == 1 source = node.all_input_nodes[0] source_outputs = node_name_to_values[source.name] - assert isinstance( - source_outputs, Sequence - ), f"Expected {source.name} to output sequence, got {node_name_to_values[source.name]}" + assert isinstance(source_outputs, Sequence), ( + f"Expected {source.name} to output sequence, got {node_name_to_values[source.name]}" + ) index = typing.cast(int, node.args[1]) value = source_outputs[index] # Save the getitem value to the values mapping to in case @@ -649,9 +649,9 @@ def _handle_output_node( # for example, a subgraph has multiple outputs. We flatten them all as ONNX graph outputs for output in node.args[0]: # type: ignore[index,union-attr] output_value_name = output.name # type: ignore[union-attr] - assert isinstance( - output_value_name, str - ), f"Bug: Expected {output_value_name!r} to be a string" + assert isinstance(output_value_name, str), ( + f"Bug: Expected {output_value_name!r} to be a string" + ) values = node_name_to_values[output_value_name] if isinstance(values, Sequence): graph_like.outputs.extend(values) @@ -754,9 +754,9 @@ def _get_inputs_and_attributes( return inputs, {}, [], [node.name] # type: ignore[return-value] # The target should be an ATen operator now - assert hasattr( - node.target, "_schema" - ), f"The target should be an ATen operator now, but node target {node.target} has no schema" + assert hasattr(node.target, "_schema"), ( + f"The target should be an ATen operator now, but node target {node.target} has no schema" + ) node_schema: torch.FunctionSchema = node.target._schema # This function assumes the order of arguments in FX op is the @@ -1050,9 +1050,9 @@ def _exported_program_to_onnx_program( persistent = spec.persistent value = values[value_name] - assert not isinstance( - value, Sequence - ), f"Input '{value_name}' should not be a sequence. This is unexpected." + assert not isinstance(value, Sequence), ( + f"Input '{value_name}' should not be a sequence. This is unexpected." + ) value.metadata_props["pkg.torch.export.graph_signature.InputSpec.kind"] = ( input_kind.name diff --git a/torch/onnx/_internal/exporter/_schemas.py b/torch/onnx/_internal/exporter/_schemas.py index a8653a2e6ef..3aa8b0e0c7e 100644 --- a/torch/onnx/_internal/exporter/_schemas.py +++ b/torch/onnx/_internal/exporter/_schemas.py @@ -307,9 +307,9 @@ def _get_allowed_types_from_type_annotation( allowed_types = set() subtypes = typing.get_args(type_) for subtype in subtypes: - assert ( - subtype is not type(None) - ), "Union should not contain None type because it is handled by _is_optional." + assert subtype is not type(None), ( + "Union should not contain None type because it is handled by _is_optional." + ) allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) return allowed_types diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py index ddb0433b9d1..8f5b646c3df 100644 --- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py +++ b/torch/onnx/_internal/fx/fx_onnx_interpreter.py @@ -602,9 +602,9 @@ class FxOnnxInterpreter: raise RuntimeError( f"Unsupported type(node.meta['val']) for placeholder: {type(fake_tensor)}" ) - assert ( - output is not None - ), f"Node creates None with target={node.target} and name={node.name}" + assert output is not None, ( + f"Node creates None with target={node.target} and name={node.name}" + ) assert isinstance(output, onnxscript_graph_building.TorchScriptTensor) assert isinstance(output, onnxscript.tensor.Tensor) @@ -631,9 +631,9 @@ class FxOnnxInterpreter: onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name] # type: ignore[union-attr,index] index = node.args[1] value = onnx_tensor_tuple[index] # type: ignore[index] - assert ( - value is not None - ), f"Node creates None with target={node.target} and name={node.name}" + assert value is not None, ( + f"Node creates None with target={node.target} and name={node.name}" + ) assert isinstance( value, (onnxscript_graph_building.TorchScriptTensor, tuple) ), type(value) @@ -664,9 +664,9 @@ class FxOnnxInterpreter: onnxscript_graph_building.TorchScriptTensor | tuple[onnxscript_graph_building.TorchScriptTensor, ...] ) = symbolic_fn(*onnx_args, **onnx_kwargs) - assert ( - output is not None - ), f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}" + assert output is not None, ( + f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}" + ) # Assign type and shape from fx graph. _fill_tensor_shape_type(output, node.name, node.meta["val"]) # One fx node could produce multiple outputs (e.g., tuple of tensors); in @@ -694,9 +694,9 @@ class FxOnnxInterpreter: # tensor, etc), we flatten the collection and register each element as output. flat_args, _ = _pytree.tree_flatten(node.args[0]) for arg in flat_args: - assert isinstance( - arg, torch.fx.Node - ), f"arg must be a torch.fx.Node, not {type(arg)}" + assert isinstance(arg, torch.fx.Node), ( + f"arg must be a torch.fx.Node, not {type(arg)}" + ) onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[arg.name] onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple) @@ -735,15 +735,15 @@ class FxOnnxInterpreter: root_fx_graph_module: The root FX module. onnxfunction_dispatcher: The dispatcher. """ - assert isinstance( - node.target, str - ), f"node.target must be a str, not {type(node.target)} for node {node}." + assert isinstance(node.target, str), ( + f"node.target must be a str, not {type(node.target)} for node {node}." + ) sub_module = root_fx_graph_module.get_submodule(node.target) - assert isinstance( - sub_module, torch.fx.GraphModule - ), f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}." + assert isinstance(sub_module, torch.fx.GraphModule), ( + f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}." + ) sub_onnxscript_graph = self.run( sub_module, onnxfunction_dispatcher, parent_onnxscript_graph diff --git a/torch/onnx/_internal/fx/passes/readability.py b/torch/onnx/_internal/fx/passes/readability.py index a27bbe3323f..83993cf25d2 100644 --- a/torch/onnx/_internal/fx/passes/readability.py +++ b/torch/onnx/_internal/fx/passes/readability.py @@ -41,9 +41,9 @@ class RestoreParameterAndBufferNames(_pass.Transform): ) -> None: """Rename the parameter/buffer and replace corresponding nodes with new nodes of updated target.""" assert len(nodes) > 0, "`nodes` cannot be empty" - assert ( - len({node.target for node in nodes}) == 1 - ), "`nodes` must all have same `target`" + assert len({node.target for node in nodes}) == 1, ( + "`nodes` must all have same `target`" + ) old_name = nodes[0].target assert isinstance(old_name, str), f"Expected str, got type({old_name})" # Parameter/buffer name cannot contain "." @@ -74,9 +74,9 @@ class RestoreParameterAndBufferNames(_pass.Transform): to the same objects, allowing us to use it as key to retrieve the original name. """ assert len(args) == 0, "RestoreParameterAndBufferNames does not take any args" - assert ( - len(kwargs) == 0 - ), "RestoreParameterAndBufferNames does not take any kwargs" + assert len(kwargs) == 0, ( + "RestoreParameterAndBufferNames does not take any kwargs" + ) # state_to_readable_name[parameter/buffer] returns the original readable name of # the parameter/buffer. E.g., "self.linear.weight". state_to_readable_name: dict[torch.nn.Parameter | torch.Tensor, str] = {} @@ -95,9 +95,9 @@ class RestoreParameterAndBufferNames(_pass.Transform): for node in self.module.graph.nodes: if node.op == "get_attr": - assert isinstance( - node.target, str - ), f"Expected str, got type({node.target})" + assert isinstance(node.target, str), ( + f"Expected str, got type({node.target})" + ) if node.target.find(".") != -1: raise RuntimeError( f"Unexpected target {node.target} in get_attr, found '.' in target. " diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py index 67220fadd17..f0878493e99 100644 --- a/torch/onnx/_internal/fx/passes/type_promotion.py +++ b/torch/onnx/_internal/fx/passes/type_promotion.py @@ -280,9 +280,9 @@ class ReductionTypePromotionRule(TypePromotionRule): def preview_type_promotion( self, args: tuple, kwargs: dict ) -> TypePromotionSnapshot: - assert ( - len(args) >= 1 - ), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument" + assert len(args) >= 1, ( + f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument" + ) arg = args[0] assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor" dtype: torch.dtype | None = kwargs.get("dtype", None) @@ -319,9 +319,9 @@ class AllOrAnyReductionTypePromotionRule(ReductionTypePromotionRule): def preview_type_promotion( self, args: tuple, kwargs: dict ) -> TypePromotionSnapshot: - assert ( - len(args) >= 1 - ), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument" + assert len(args) >= 1, ( + f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument" + ) arg = args[0] assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor" computation_dtype = torch.bool @@ -344,9 +344,9 @@ class SumLikeReductionTypePromotionRule(ReductionTypePromotionRule): def preview_type_promotion( self, args: tuple, kwargs: dict ) -> TypePromotionSnapshot: - assert ( - len(args) >= 1 - ), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument" + assert len(args) >= 1, ( + f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument" + ) arg = args[0] assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor" dtype: torch.dtype | None = kwargs.get("dtype", None) @@ -1319,17 +1319,17 @@ def find_compatible_op_overload( op_trace_dispatch_mode = _OpTraceDispatchMode() with op_trace_dispatch_mode: op(*args, **kwargs) - assert ( - len(op_trace_dispatch_mode.traced_ops) >= 1 - ), "Expected at least 1 traced op, got 0" + assert len(op_trace_dispatch_mode.traced_ops) >= 1, ( + "Expected at least 1 traced op, got 0" + ) new_op_overload = op_trace_dispatch_mode.traced_ops[0] - assert isinstance( - new_op_overload, torch._ops.OpOverload - ), f"Expected OpOverload, got {type(new_op_overload)}" - assert ( - new_op_overload.overloadpacket == op - ), f"Expected same OpOverload packet, got {new_op_overload.overloadpacket} != {op}" + assert isinstance(new_op_overload, torch._ops.OpOverload), ( + f"Expected OpOverload, got {type(new_op_overload)}" + ) + assert new_op_overload.overloadpacket == op, ( + f"Expected same OpOverload packet, got {new_op_overload.overloadpacket} != {op}" + ) return new_op_overload @@ -1398,9 +1398,9 @@ class _TypePromotionInterpreter(torch.fx.Interpreter): assert node_val is not None, f"Node {node} node.meta['val'] is not set." args, kwargs = self.fetch_args_kwargs_from_env(node) target = node.target - assert isinstance( - target, torch._ops.OpOverload - ), f"Expected OpOverload, got {type(target)}" + assert isinstance(target, torch._ops.OpOverload), ( + f"Expected OpOverload, got {type(target)}" + ) node.target = find_compatible_op_overload(target.overloadpacket, args, kwargs) new_node_val = self._run_node_and_set_meta(node) diff --git a/torch/onnx/_internal/fx/passes/virtualization.py b/torch/onnx/_internal/fx/passes/virtualization.py index 456c25fee77..504dea1d842 100644 --- a/torch/onnx/_internal/fx/passes/virtualization.py +++ b/torch/onnx/_internal/fx/passes/virtualization.py @@ -48,9 +48,9 @@ class ReplaceGetAttrWithPlaceholder(_pass.Transform): @property def replaced_attrs(self) -> tuple[torch.Tensor, ...]: """The list of replaced weight tensors.""" - assert ( - self._replaced_attrs is not None - ), "Must run ReplaceGetAttrWithPlaceholder first" + assert self._replaced_attrs is not None, ( + "Must run ReplaceGetAttrWithPlaceholder first" + ) return self._replaced_attrs def _run(self, *args, **kwargs) -> torch.fx.GraphModule: diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py index 1d7d5806cb0..f93e68ce5a1 100644 --- a/torch/onnx/_internal/io_adapter.py +++ b/torch/onnx/_internal/io_adapter.py @@ -639,9 +639,9 @@ class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep): flattened_outputs: The flattened model outputs. """ - assert isinstance( - model, torch_export.ExportedProgram - ), "'model' must be torch_export.ExportedProgram" + assert isinstance(model, torch_export.ExportedProgram), ( + "'model' must be torch_export.ExportedProgram" + ) ordered_buffers = tuple( model.state_dict[name] if name in model.state_dict diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index dbf6beb648d..f609b4452bb 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -411,9 +411,9 @@ def quantized_args( output = fn(g, *non_quantized_args, **kwargs) assert _scale is not None, "Bug: Scale must be set for quantized operator" - assert ( - _zero_point is not None - ), "Bug: Zero point must be set for quantized operator" + assert _zero_point is not None, ( + "Bug: Zero point must be set for quantized operator" + ) if quantize_output: return quantize_helper(g, output, _scale, _zero_point) diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py index ae33ddf58c6..8bc6f0f9f4d 100644 --- a/torch/onnx/symbolic_opset14.py +++ b/torch/onnx/symbolic_opset14.py @@ -145,10 +145,12 @@ def scaled_dot_product_attention( scale: torch._C.Value | None = None, enable_gqa: bool = False, ): - assert (not is_causal) or ( - is_causal and symbolic_helper._is_none(attn_mask) - ), "is_causal and attn_mask cannot be set at the same time" - assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" + assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), ( + "is_causal and attn_mask cannot be set at the same time" + ) + assert not enable_gqa, ( + "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" + ) if symbolic_helper._is_none(scale): scale = _attention_scale(g, query) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index c73b5a1b23a..371745664f4 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -1199,9 +1199,9 @@ def prelu(g: jit_utils.GraphContext, self, weight): weight_rank = 0 if self_rank is not None and weight_rank is not None: - assert ( - self_rank >= weight_rank - ), f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}" + assert self_rank >= weight_rank, ( + f"rank(x) should be >= rank(slope) but got {self_rank} < {weight_rank}" + ) return g.op("PRelu", self, weight) @@ -4056,9 +4056,9 @@ def repeat_interleave( "Unsupported for cases with dynamic repeats", self, ) - assert ( - repeats_sizes[0] == input_sizes[dim] - ), "repeats must have the same size as input along dim" + assert repeats_sizes[0] == input_sizes[dim], ( + "repeats must have the same size as input along dim" + ) reps = repeats_sizes[0] else: raise errors.SymbolicValueError("repeats must be 0-dim or 1-dim tensor", self) diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index 049195a4bf1..1a8ae659a63 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -229,9 +229,9 @@ def _compare_onnx_pytorch_outputs_in_np( pt_outs: _OutputsType, options: VerificationOptions, ): - assert ( - len(onnx_outs) == len(pt_outs) - ), f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})" + assert len(onnx_outs) == len(pt_outs), ( + f"Number of outputs differ ONNX runtime: ({len(onnx_outs)}) PyTorch: ({len(pt_outs)})" + ) acceptable_error_percentage = options.acceptable_error_percentage if acceptable_error_percentage and ( acceptable_error_percentage > 1.0 or acceptable_error_percentage < 0.0 @@ -1561,9 +1561,9 @@ class GraphInfo: pt_outs = self.pt_outs graph_outputs = list(self.graph.outputs()) assert pt_outs is not None - assert len(graph_outputs) == len( - pt_outs - ), f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}" + assert len(graph_outputs) == len(pt_outs), ( + f"{len(graph_outputs)} vs {len(pt_outs)}\nGraph: {self.graph}" + ) return {v.debugName(): o for v, o in zip(graph_outputs, pt_outs)} def _args_and_params_for_partition_graph( @@ -1577,9 +1577,9 @@ class GraphInfo: args = tuple(bridge_kwargs[k] for k in input_names if k in bridge_kwargs) args += tuple(full_kwargs[k] for k in input_names if k in full_kwargs) params = {k: full_params[k] for k in input_names if k in full_params} - assert len(args) + len(params) == len( - input_names - ), f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}" + assert len(args) + len(params) == len(input_names), ( + f"{len(args)} + {len(params)} vs {len(input_names)}: {input_names}" + ) return args, params def verify_export( diff --git a/torch/overrides.py b/torch/overrides.py index 97fc86af920..a76c0a2ffcf 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -1805,9 +1805,9 @@ has_torch_function_variadic = _add_docstr( @functools.lru_cache(None) -def _get_overridable_functions() -> ( - tuple[dict[Any, list[Callable]], dict[Callable, str]] -): +def _get_overridable_functions() -> tuple[ + dict[Any, list[Callable]], dict[Callable, str] +]: overridable_funcs = collections.defaultdict(list) index = {} tested_namespaces = [ diff --git a/torch/serialization.py b/torch/serialization.py index 514cd318c1c..f5f7b652496 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -2023,9 +2023,9 @@ def _load( typename = _maybe_decode_ascii(saved_id[0]) data = saved_id[1:] - assert ( - typename == "storage" - ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + assert typename == "storage", ( + f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + ) storage_type, key, location, numel = data if storage_type is torch.UntypedStorage: dtype = torch.uint8 diff --git a/torchgen/_autoheuristic/ah_tree.py b/torchgen/_autoheuristic/ah_tree.py index 77b8405fe71..c2ec2b8d947 100644 --- a/torchgen/_autoheuristic/ah_tree.py +++ b/torchgen/_autoheuristic/ah_tree.py @@ -208,9 +208,9 @@ class DecisionTree: if name in dummy_col_2_col_val: (orig_name, value) = dummy_col_2_col_val[name] predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':" - assert ( - threshold == 0.5 - ), f"expected threshold to be 0.5 but is {threshold}" + assert threshold == 0.5, ( + f"expected threshold to be 0.5 but is {threshold}" + ) else: predicate = ( f"{indent}if context.get_value('{name}') <= {threshold}:" diff --git a/torchgen/_autoheuristic/train_decision.py b/torchgen/_autoheuristic/train_decision.py index 9fd50047461..853509e42ff 100644 --- a/torchgen/_autoheuristic/train_decision.py +++ b/torchgen/_autoheuristic/train_decision.py @@ -446,9 +446,9 @@ class AHTrainDecisionTree(AHTrain): for row in group.itertuples(): choice2time[row.choice] = row.median_execution_time - assert ( - len(unique_choices) == len(group) - ), f"len(unique_choices) != len(group): {len(unique_choices)} != {len(group)}" + assert len(unique_choices) == len(group), ( + f"len(unique_choices) != len(group): {len(unique_choices)} != {len(group)}" + ) return pd.Series( { @@ -869,9 +869,9 @@ class DecisionEvaluator: top_k_choices = self.top_k_classes( self.model, prob, k=self.k, avail_choices=avail_choices ) - assert ( - true in avail_choices - ), f"Best choice {true} not in available choices {avail_choices}" + assert true in avail_choices, ( + f"Best choice {true} not in available choices {avail_choices}" + ) default_config = self.train.get_default_config(self.df.iloc[i]) self.eval_prediction( avail_choices, diff --git a/torchgen/_autoheuristic/train_regression.py b/torchgen/_autoheuristic/train_regression.py index 2d5b012915e..76e7d952978 100644 --- a/torchgen/_autoheuristic/train_regression.py +++ b/torchgen/_autoheuristic/train_regression.py @@ -405,9 +405,9 @@ class AHTrainRegressionTree(AHTrain): if name in dummy_col_2_col_val: (orig_name, value) = dummy_col_2_col_val[name] predicate = f"{indent}if str(context.get_value('{orig_name}')) != '{value}':" - assert ( - threshold == 0.5 - ), f"expected threshold to be 0.5 but is {threshold}" + assert threshold == 0.5, ( + f"expected threshold to be 0.5 but is {threshold}" + ) else: predicate = ( f"{indent}if context.get_value('{name}') <= {threshold}:" diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index 7695367a030..554bfa4a5c7 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -250,7 +250,9 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType: elif t.name == BaseTy.Scalar: return BaseCType(scalarT) elif isinstance(t, ListType): - assert not mutable, "Native functions should never return a mutable tensor list. They should return void." + assert not mutable, ( + "Native functions should never return a mutable tensor list. They should return void." + ) elem = returntype_type(t.elem, mutable=False) assert t.size is None, f"fixed size list returns not supported: {t}" return VectorCType(elem) diff --git a/torchgen/api/lazy.py b/torchgen/api/lazy.py index 1422a5d7383..1d308afd813 100644 --- a/torchgen/api/lazy.py +++ b/torchgen/api/lazy.py @@ -237,9 +237,9 @@ class LazyArgument: @property def lazy_type(self) -> CType: - assert ( - self.lazy_type_ is not None - ), f"Attempted to access lazy_type for invalid argument {self.name}" + assert self.lazy_type_ is not None, ( + f"Attempted to access lazy_type for invalid argument {self.name}" + ) return self.lazy_type_ @@ -374,9 +374,9 @@ class LazyIrSchema: curr_args = curr_args.all() for arg in curr_args: if isGeneratorType(arg.type): - assert ( - self.generator_arg is None - ), "We expect there is only one generator arg" + assert self.generator_arg is None, ( + "We expect there is only one generator arg" + ) self.generator_arg = NamedCType( arg.name, arg.type, # type:ignore[arg-type] diff --git a/torchgen/api/types/signatures.py b/torchgen/api/types/signatures.py index 2b9b9c27f69..384eeeb8e48 100644 --- a/torchgen/api/types/signatures.py +++ b/torchgen/api/types/signatures.py @@ -408,7 +408,9 @@ def kernel_signature( meta = backend_index.get_kernel(f) symint = meta is not None and meta.supports_symint() if symint: - assert f.func.has_symint(), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema" + assert f.func.has_symint(), ( + f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema" + ) if backend_index.external: return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint) else: diff --git a/torchgen/dest/lazy_ir.py b/torchgen/dest/lazy_ir.py index 8f260cc923a..b912b8f2427 100644 --- a/torchgen/dest/lazy_ir.py +++ b/torchgen/dest/lazy_ir.py @@ -481,9 +481,9 @@ class GenLazyNativeFuncDefinition: optional_devices = [ a.name for a in scalar_args if a.lazy_type == optional_device ] - assert ( - len(value_types_names) > 0 or len(optional_devices) > 0 - ), "Expected at least one Value or Device type" + assert len(value_types_names) > 0 or len(optional_devices) > 0, ( + "Expected at least one Value or Device type" + ) get_device_str = ( f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})" ) @@ -580,9 +580,9 @@ std::vector shapes{torch::lazy::Shape(out_meta.scalar_type() # xla uses an instance method for tensor creation, for the time being if self.create_from_first_tensor: # TODO(whc) remove this if XLA switches to using static method for creation - assert ( - first_tensor_name is not None - ), "Requires first tensor to create lazy tensor" + assert first_tensor_name is not None, ( + "Requires first tensor to create lazy tensor" + ) return f"{first_tensor_name}.{self.create_tensor}" return f"{self.backend_namespace}::{self.create_tensor}" @@ -595,9 +595,9 @@ std::vector shapes{torch::lazy::Shape(out_meta.scalar_type() {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));""" if returns_length > 1: - assert ( - len(value_types_names) > 0 - ), "Code below assumes there is at least one tensor arg" + assert len(value_types_names) > 0, ( + "Code below assumes there is at least one tensor arg" + ) bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors; for (int i = 0; i < {returns_length}; i++) {{ lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device)); diff --git a/torchgen/dest/ufunc.py b/torchgen/dest/ufunc.py index e66c9a4e526..832316d018e 100644 --- a/torchgen/dest/ufunc.py +++ b/torchgen/dest/ufunc.py @@ -186,9 +186,9 @@ def compute_ufunc_cuda_functors( ufunc_name = loops[lk].name else: # See Note [ScalarOnly and Generic must match names for CUDA] - assert ( - ufunc_name == loops[lk].name - ), "ScalarOnly and Generic must have same ufunc name" + assert ufunc_name == loops[lk].name, ( + "ScalarOnly and Generic must have same ufunc name" + ) supported_dtypes |= loops[lk].supported_dtypes assert ufunc_name is not None diff --git a/torchgen/executorch/api/custom_ops.py b/torchgen/executorch/api/custom_ops.py index 45f7f8e3cda..641b7e9c941 100644 --- a/torchgen/executorch/api/custom_ops.py +++ b/torchgen/executorch/api/custom_ops.py @@ -65,9 +65,9 @@ class ComputeNativeFunctionStub: {comma.join([r.name for r in f.func.arguments.out])} )""" else: - assert all( - a.type == BaseType(BaseTy.Tensor) for a in f.func.returns - ), f"Only support tensor returns but got {f.func.returns}" + assert all(a.type == BaseType(BaseTy.Tensor) for a in f.func.returns), ( + f"Only support tensor returns but got {f.func.returns}" + ) # Returns a tuple of empty tensors tensor_type = "at::Tensor" comma = ", " diff --git a/torchgen/executorch/api/et_cpp.py b/torchgen/executorch/api/et_cpp.py index d09f0079197..72b0551d029 100644 --- a/torchgen/executorch/api/et_cpp.py +++ b/torchgen/executorch/api/et_cpp.py @@ -181,7 +181,9 @@ def returntype_type(t: Type, *, mutable: bool) -> CType: elif t.name == BaseTy.Scalar: return BaseCType(scalarT) elif isinstance(t, ListType): - assert not mutable, "Native functions should never return a mutable tensor list. They should return void." + assert not mutable, ( + "Native functions should never return a mutable tensor list. They should return void." + ) elem = returntype_type(t.elem, mutable=False) assert t.size is None, f"fixed size list returns not supported: {t}" return VectorCType(elem) diff --git a/torchgen/executorch/model.py b/torchgen/executorch/model.py index 8f80a951ae3..6be7501ebea 100644 --- a/torchgen/executorch/model.py +++ b/torchgen/executorch/model.py @@ -94,9 +94,9 @@ class ETKernelKey: assert type_alias in type_alias_map, "Undefined type alias: " + str( type_alias ) - assert ( - dim_order in dim_order_alias_map - ), f"Undefined dim_order alias: {dim_order}" + assert dim_order in dim_order_alias_map, ( + f"Undefined dim_order alias: {dim_order}" + ) dtype_alias_used.add(type_alias) # Generate all permutations of dtype alias values @@ -193,9 +193,9 @@ class ETKernelIndex: index: dict[OperatorName, BackendMetadata] = {} for op in self.index: kernel_dict = self.index[op] - assert ( - len(kernel_dict.values()) == 1 - ), f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}" + assert len(kernel_dict.values()) == 1, ( + f"Can't convert ETKernelIndex to BackendIndex because {op} has more than one kernels. Got {kernel_dict}" + ) index[op] = kernel_dict.get( ETKernelKey(default=True), BackendMetadata(kernel="", structured=False, cpp_namespace=""), diff --git a/torchgen/gen.py b/torchgen/gen.py index 63dd621cdd8..609d338887e 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -1433,9 +1433,9 @@ def get_grouped_by_view_native_functions( assert kind not in grouped_by_views[schema] grouped_by_views[schema][kind] = f else: - assert ( - view_kind not in grouped_by_views[schema] - ), f"{view_kind} already in {grouped_by_views[schema].keys()}" + assert view_kind not in grouped_by_views[schema], ( + f"{view_kind} already in {grouped_by_views[schema].keys()}" + ) grouped_by_views[schema][view_kind] = f return list(concatMap(maybe_create_view_group, grouped_by_views.values())) @@ -1483,9 +1483,9 @@ def get_ns_grouped_kernels( native_function_namespaces.add(namespace) else: namespace = DEFAULT_KERNEL_NAMESPACE - assert ( - len(native_function_namespaces) <= 1 - ), f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}" + assert len(native_function_namespaces) <= 1, ( + f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}" + ) ns_grouped_kernels[namespace].extend( native_function_decl_gen(f, backend_idx) ) diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index 8c6d29258d4..07097010f8f 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -80,34 +80,34 @@ def parse_backend_yaml( # Mostly just defaulting to false to stick with LazyTensor convention. use_out_as_primary = yaml_values.pop("use_out_as_primary", False) - assert isinstance( - use_out_as_primary, bool - ), f"You must provide either True or False for use_out_as_primary. Provided: {use_out_as_primary}" + assert isinstance(use_out_as_primary, bool), ( + f"You must provide either True or False for use_out_as_primary. Provided: {use_out_as_primary}" + ) use_device_guard = yaml_values.pop("device_guard", False) - assert isinstance( - use_device_guard, bool - ), f"You must provide either True or False for device_guard. Provided: {use_device_guard}" + assert isinstance(use_device_guard, bool), ( + f"You must provide either True or False for device_guard. Provided: {use_device_guard}" + ) supported = yaml_values.pop("supported", []) if supported is None: supported = [] # Allow an empty list of supported ops - assert isinstance( - supported, list - ), f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})' + assert isinstance(supported, list), ( + f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})' + ) symint = yaml_values.pop("symint", []) if symint is None: symint = [] # Allow an empty list of symint ops - assert isinstance( - symint, list - ), f'expected "symint" to be a list, but got: {supported} (of type {type(supported)})' + assert isinstance(symint, list), ( + f'expected "symint" to be a list, but got: {supported} (of type {type(supported)})' + ) symint_set = set(symint) supported_autograd = yaml_values.pop("autograd", []) - assert isinstance( - supported_autograd, list - ), f'expected "autograd" to be a list, but got: {supported_autograd}' + assert isinstance(supported_autograd, list), ( + f'expected "autograd" to be a list, but got: {supported_autograd}' + ) # full_codegen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py full_codegen = yaml_values.pop("full_codegen", []) @@ -135,9 +135,9 @@ def parse_backend_yaml( metadata: dict[OperatorName, BackendMetadata] = {} for op in backend_ops: op_name = OperatorName.parse(op) - assert ( - op_name in native_functions_map - ), f"Found an invalid operator name: {op_name}" + assert op_name in native_functions_map, ( + f"Found an invalid operator name: {op_name}" + ) # See Note [External Backends Follow Dispatcher API] kernel_name = dispatcher.name(native_functions_map[op_name].func) if op in symint_ops: @@ -238,11 +238,11 @@ the behavior of autograd for some operators on your backend. However "Autograd{b forward_kernels = [f for f in forward_kernels if f is not None] backward_kernels = [f for f in backward_kernels if f is not None] - assert ( - len(forward_kernels) == 0 or len(backward_kernels) == 0 - ), f'Currently, all variants of an op must either be registered to a backend key, or to a backend\'s \ + assert len(forward_kernels) == 0 or len(backward_kernels) == 0, ( + f'Currently, all variants of an op must either be registered to a backend key, or to a backend\'s \ autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! \ {forward_kernels[0].kernel} is listed under "supported", but {backward_kernels[0].kernel} is listed under "autograd".' + ) return ParsedExternalYaml( backend_key, autograd_key, class_name, cpp_namespace, backend_indices diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py index a897bb5e1f9..306333f1eae 100644 --- a/torchgen/gen_executorch.py +++ b/torchgen/gen_executorch.py @@ -444,9 +444,9 @@ def get_ns_grouped_kernels( native_function_namespaces.add(namespace) else: namespace = DEFAULT_KERNEL_NAMESPACE - assert ( - len(native_function_namespaces) <= 1 - ), f"Codegen only supports one namespace per operator, got {native_function_namespaces}" + assert len(native_function_namespaces) <= 1, ( + f"Codegen only supports one namespace per operator, got {native_function_namespaces}" + ) ns_grouped_kernels[namespace].extend( native_function_decl_gen(f, kernel_index) ) diff --git a/torchgen/local.py b/torchgen/local.py index 19045f4a948..8d7016bbfaf 100644 --- a/torchgen/local.py +++ b/torchgen/local.py @@ -40,8 +40,7 @@ def use_const_ref_for_mutable_tensors() -> bool: def use_ilistref_for_tensor_lists() -> bool: assert _locals.use_ilistref_for_tensor_lists is not None, ( - "need to initialize local.use_ilistref_for_tensor_lists with " - "local.parametrize" + "need to initialize local.use_ilistref_for_tensor_lists with local.parametrize" ) return _locals.use_ilistref_for_tensor_lists diff --git a/torchgen/model.py b/torchgen/model.py index 54e03d8c4e6..4c2d13f0b49 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -627,9 +627,9 @@ class NativeFunction: assert isinstance(use_const_ref_for_mutable_tensors, bool) if use_const_ref_for_mutable_tensors: - assert ( - not func.arguments.out - ), "see https://github.com/pytorch/pytorch/issues/145522" + assert not func.arguments.out, ( + "see https://github.com/pytorch/pytorch/issues/145522" + ) variants_s = e.pop("variants", "function") assert isinstance(variants_s, str) @@ -643,9 +643,9 @@ class NativeFunction: raise AssertionError(f"illegal variant {v}") manual_kernel_registration = e.pop("manual_kernel_registration", False) - assert isinstance( - manual_kernel_registration, bool - ), f"not a bool: {manual_kernel_registration}" + assert isinstance(manual_kernel_registration, bool), ( + f"not a bool: {manual_kernel_registration}" + ) manual_cpp_binding = e.pop("manual_cpp_binding", False) assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}" @@ -654,9 +654,9 @@ class NativeFunction: assert isinstance(device_guard, bool), f"not a bool: {device_guard}" device_check_s = e.pop("device_check", None) - assert device_check_s is None or isinstance( - device_check_s, str - ), f"not a str: {device_check_s}" + assert device_check_s is None or isinstance(device_check_s, str), ( + f"not a str: {device_check_s}" + ) assert ( device_check_s is None or device_check_s in DeviceCheckType.__members__ ), f"illegal device_check: {device_check_s}" @@ -682,26 +682,26 @@ class NativeFunction: structured_delegate = OperatorName.parse(structured_delegate_s) structured_inherits = e.pop("structured_inherits", None) - assert structured_inherits is None or isinstance( - structured_inherits, str - ), f"not a str: {structured_inherits}" + assert structured_inherits is None or isinstance(structured_inherits, str), ( + f"not a str: {structured_inherits}" + ) assert structured_inherits is None or "::" not in structured_inherits, ( "namespace is not supported in structured inherits," " using the same namespace as the native function" ) python_module = e.pop("python_module", None) - assert python_module is None or isinstance( - python_module, str - ), f"not a str: {python_module}" - assert ( - python_module is None or Variant.method not in variants - ), "functions in modules cannot be methods" + assert python_module is None or isinstance(python_module, str), ( + f"not a str: {python_module}" + ) + assert python_module is None or Variant.method not in variants, ( + "functions in modules cannot be methods" + ) category_override = e.pop("category_override", None) - assert category_override is None or isinstance( - category_override, str - ), f"not a str: {category_override}" + assert category_override is None or isinstance(category_override, str), ( + f"not a str: {category_override}" + ) precomputed_dict = e.pop("precomputed", None) assert precomputed_dict is None or structured is True @@ -740,12 +740,12 @@ class NativeFunction: for ks, v in raw_dispatch.items(): if ks == "__line__": continue # not worth tracking line numbers for dispatch entries - assert isinstance( - ks, str - ), f"illegal dispatch key '{ks}' in {raw_dispatch}" - assert isinstance( - v, str - ), f"illegal dispatch value '{v}' in {raw_dispatch}" + assert isinstance(ks, str), ( + f"illegal dispatch key '{ks}' in {raw_dispatch}" + ) + assert isinstance(v, str), ( + f"illegal dispatch value '{v}' in {raw_dispatch}" + ) for k in ks.split(","): dispatch_key = DispatchKey.parse(k.strip()) num_dispatch_keys += 1 @@ -879,9 +879,9 @@ class NativeFunction: import torchgen.api.ufunc as ufunc for dispatch_key in UFUNC_DISPATCH_KEYS: - assert ( - dispatch_key not in dispatch - ), f"ufunc should not have explicit dispatch entry for {dispatch_key}" + assert dispatch_key not in dispatch, ( + f"ufunc should not have explicit dispatch entry for {dispatch_key}" + ) dispatch[dispatch_key] = BackendMetadata( kernel=ufunc.schema_kernel_name(func, dispatch_key), structured=True, @@ -997,31 +997,31 @@ class NativeFunction: "Put structured field on the out= " "variant of a function; did you mean structured_delegate?" ) - assert ( - self.device_guard - ), "device_guard: False is not respected by structured kernels" + assert self.device_guard, ( + "device_guard: False is not respected by structured kernels" + ) if self.structured_delegate: assert self.func.kind() != SchemaKind.out, ( "structured_delegate field not allowed " "on out= functions; did you mean structured?" ) - assert ( - self.device_guard - ), "device_guard: False is not respected by structured kernels" + assert self.device_guard, ( + "device_guard: False is not respected by structured kernels" + ) # Technically, with the asserts above, this assert is impossible to # happen - assert not ( - self.structured and self.structured_delegate - ), "Cannot have both structured and structured_delegate on function" + assert not (self.structured and self.structured_delegate), ( + "Cannot have both structured and structured_delegate on function" + ) defaulted_arguments = { a.name for a in self.func.schema_order_arguments() if a.default is not None } invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments) assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}" if self.structured_inherits is not None: - assert ( - self.structured - ), "structured_inherits must also imply structured: True" + assert self.structured, ( + "structured_inherits must also imply structured: True" + ) if str(self.func.name).startswith("_foreach"): assert self.device_check == DeviceCheckType.NoCheck, ( "foreach kernels fall back to slow path when tensor are on different devices, " @@ -1299,9 +1299,9 @@ class BackendIndex: ) -> None: for k, v in child_index.items(): for op_name, metadata in v.items(): - assert ( - op_name not in parent_index[k] - ), f"duplicate operator {op_name} for dispatch key {k}" + assert op_name not in parent_index[k], ( + f"duplicate operator {op_name} for dispatch key {k}" + ) parent_index[k][op_name] = metadata def primary(self, g: NativeFunctionsGroup) -> NativeFunction: @@ -1450,9 +1450,9 @@ class FunctionSchema: # We also enforce that if you have any mutable, positional args, then they are not returned. # This makes it easier to group these functions properly with their functional/out= counterparts. for a in self.arguments.post_self_positional_mutable: - assert not any( - a.annotation == r.annotation for r in self.returns - ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}" + assert not any(a.annotation == r.annotation for r in self.returns), ( + f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}" + ) # Invariant: we expect out arguments to appear as keyword arguments in the schema. # This means that all mutable returns should be aliased to a keyword argument # (except for "self", which we explicitly don't treat as an out argument because of its use in methods) @@ -1475,9 +1475,9 @@ class FunctionSchema: # (1) It's more annoying to handle properly # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple. # Instead, we expect the (a!) argument to not be returned. - assert ( - len(mutable_returns) == 0 or len(immutable_returns) == 0 - ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}" + assert len(mutable_returns) == 0 or len(immutable_returns) == 0, ( + f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}" + ) for ret in mutable_returns: assert any(ret.annotation == arg.annotation for arg in out_and_self), ( 'All mutable returns must be aliased either to a keyword argument, or to "self". ' @@ -1490,23 +1490,22 @@ class FunctionSchema: # and all other types of out= op schemas should return void. # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that. if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out): - assert ( - len(self.returns) == 0 - ), "out= ops that accept tensor lists as out arguments " + assert len(self.returns) == 0, ( + "out= ops that accept tensor lists as out arguments " + ) "are expected to have no return type (since you can't do method chaining on them)" else: # mutable keyword arguments whose name has _scratch_ prefix are # scratch tensors for memory planning and should not be returned - assert ( - len( - [ - arg - for arg in self.arguments.out - if not arg.name.startswith("_scratch_") - ] - ) - == len(self.returns) - ), "Must return as many arguments as there are out arguments, or no return at all" + assert len( + [ + arg + for arg in self.arguments.out + if not arg.name.startswith("_scratch_") + ] + ) == len(self.returns), ( + "Must return as many arguments as there are out arguments, or no return at all" + ) if self.name.name.inplace: self_a = self.arguments.self_arg @@ -1599,12 +1598,14 @@ class FunctionSchema: if is_inplace: return SchemaKind.inplace elif is_scratch: - assert ( - is_out - ), "invariant: all scratch operators are expected to be out= operators too" + assert is_out, ( + "invariant: all scratch operators are expected to be out= operators too" + ) return SchemaKind.scratch elif is_out: - assert not is_scratch, "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!" # noqa: B950 + assert not is_scratch, ( + "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!" + ) # noqa: B950 return SchemaKind.out elif is_mutable: return SchemaKind.mutable @@ -1801,13 +1802,13 @@ class Annotation: before_alias = m.group(1) + (m.group(2) if m.group(2) else "") alias_set = tuple(before_alias.split("|")) is_write = m.group(3) == "!" - assert not ( - is_write and len(alias_set) > 1 - ), f"alias set larger than 1 is not mutable, got {ann} instead." + assert not (is_write and len(alias_set) > 1), ( + f"alias set larger than 1 is not mutable, got {ann} instead." + ) after_set = tuple(m.group(5).split("|")) if m.group(5) else () - assert not ( - len(before_alias) > 1 and len(after_set) > 1 - ), f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead." + assert not (len(before_alias) > 1 and len(after_set) > 1), ( + f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead." + ) r = Annotation( alias_set=alias_set, is_write=is_write, alias_set_after=after_set ) @@ -2348,9 +2349,9 @@ class Arguments: if not arg: continue if arg == "*": - assert ( - arguments_acc is positional - ), "invalid syntax: kwarg-only specifier * can only occur once" + assert arguments_acc is positional, ( + "invalid syntax: kwarg-only specifier * can only occur once" + ) arguments_acc = kwarg_only continue parg = Argument.parse(arg) @@ -2475,9 +2476,9 @@ class Arguments: for a in self.pre_self_positional if a.annotation is not None and a.annotation.is_write ] - assert ( - len(mutable_pre_self_positionals) == 0 - ), "mutable pre_self_positional arguments are not currently supported in the schema" + assert len(mutable_pre_self_positionals) == 0, ( + "mutable pre_self_positional arguments are not currently supported in the schema" + ) # Names that validly are __iXXX__ indicating inplace operations. @@ -2831,9 +2832,9 @@ class Precompute: ) arg, with_list_raw = raw_replace_item.split(" -> ") - assert ( - " " not in arg - ), f"illegal kernel param name '{arg}' in precomputed parameters'" + assert " " not in arg, ( + f"illegal kernel param name '{arg}' in precomputed parameters'" + ) with_list = with_list_raw.split(",") with_list_args = [Argument.parse(name.strip()) for name in with_list] replace[arg] = with_list_args diff --git a/torchgen/utils.py b/torchgen/utils.py index 9647c05bb1e..f660d01677d 100644 --- a/torchgen/utils.py +++ b/torchgen/utils.py @@ -246,9 +246,9 @@ class FileManager: for key in sharded_keys: for shard in all_shards: if key in shard: - assert isinstance( - shard[key], list - ), "sharded keys in base_env must be a list" + assert isinstance(shard[key], list), ( + "sharded keys in base_env must be a list" + ) shard[key] = shard[key].copy() else: shard[key] = [] @@ -441,9 +441,9 @@ class NamespaceHelper: ) -> None: # cpp_namespace can be a colon joined string such as torch::lazy cpp_namespaces = namespace_str.split("::") - assert ( - len(cpp_namespaces) <= max_level - ), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}." + assert len(cpp_namespaces) <= max_level, ( + f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}." + ) self.cpp_namespace_ = namespace_str self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces]) self.epilogue_ = "\n".join( diff --git a/torchgen/yaml_utils.py b/torchgen/yaml_utils.py index 0278af84bf6..720d1944602 100644 --- a/torchgen/yaml_utils.py +++ b/torchgen/yaml_utils.py @@ -18,9 +18,9 @@ class YamlLoader(Loader): mapping = [] for key_node, value_node in node.value: key = self.construct_object(key_node, deep=deep) # type: ignore[no-untyped-call] - assert ( - key not in mapping - ), f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}" + assert key not in mapping, ( + f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}" + ) mapping.append(key) mapping = super().construct_mapping(node, deep=deep) # type: ignore[no-untyped-call] return mapping