diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py
index a59a24132d4..9e08b37fd55 100755
--- a/benchmarks/dynamo/timm_models.py
+++ b/benchmarks/dynamo/timm_models.py
@@ -167,11 +167,9 @@ def refresh_model_names():
             del all_models_family[key]
 
     chosen_models = set()
-    for value in docs_models_family.values():
-        chosen_models.add(value[0])
+    chosen_models.update(value[0] for value in docs_models_family.values())
 
-    for key, value in all_models_family.items():
-        chosen_models.add(value[0])
+    chosen_models.update(value[0] for key, value in all_models_family.items())
 
     filename = "timm_models_list.txt"
     if os.path.exists("benchmarks"):
diff --git a/benchmarks/operator_benchmark/benchmark_utils.py b/benchmarks/operator_benchmark/benchmark_utils.py
index d68a0504b88..d7e45b7c168 100644
--- a/benchmarks/operator_benchmark/benchmark_utils.py
+++ b/benchmarks/operator_benchmark/benchmark_utils.py
@@ -345,8 +345,9 @@ def get_operator_range(chars_range):
             ops_start_chars_set.add(item.lower())
             continue
         start, end = item.split("-")
-        for c in range(ord(start), ord(end) + 1):
-            ops_start_chars_set.add(chr(c).lower())
+        ops_start_chars_set.update(
+            chr(c).lower() for c in range(ord(start), ord(end) + 1)
+        )
     return ops_start_chars_set
 
 
diff --git a/test/distributed/_composable/fully_shard/test_fully_shard_init.py b/test/distributed/_composable/fully_shard/test_fully_shard_init.py
index d5297f4cc10..c6da3ab295f 100644
--- a/test/distributed/_composable/fully_shard/test_fully_shard_init.py
+++ b/test/distributed/_composable/fully_shard/test_fully_shard_init.py
@@ -144,10 +144,12 @@ class TestInitialization(FSDPTest):
         # Check that the composable module does not add any wrapper class
         local_module_classes = set()
         composable_module_classes = set()
-        for submodule in local_model.modules():
-            local_module_classes.add(type(submodule))
-        for submodule in composable_module.modules():
-            composable_module_classes.add(type(submodule))
+        local_module_classes.update(
+            type(submodule) for submodule in local_model.modules()
+        )
+        composable_module_classes.update(
+            type(submodule) for submodule in composable_module.modules()
+        )
         self.assertEqual(local_module_classes, composable_module_classes)
 
         # Check that the composable module has the same FSDP states with the
@@ -310,14 +312,14 @@ class TestInitialization(FSDPTest):
         ]
         for data_structure_name in data_structure_names:
             all_structures = set()
-            for module in (
-                composable_module.u1,
-                composable_module.u2,
-                composable_module,
-            ):
-                all_structures.add(
-                    id(getattr(fully_shard.state(module), data_structure_name))
+            all_structures.update(
+                id(getattr(fully_shard.state(module), data_structure_name))
+                for module in (
+                    composable_module.u1,
+                    composable_module.u2,
+                    composable_module,
                 )
+            )
             self.assertEqual(len(all_structures), 1)
 
diff --git a/test/distributed/fsdp/test_wrap.py b/test/distributed/fsdp/test_wrap.py
index 1bb7b2c8849..d39ba373074 100644
--- a/test/distributed/fsdp/test_wrap.py
+++ b/test/distributed/fsdp/test_wrap.py
@@ -945,8 +945,7 @@ class TestWrapUtils(TestCase):
         ignored_params = set()
         for module_name, module in model.named_modules():
             if "lora_A" in module_name:
-                for param in module.parameters():
-                    ignored_params.add(param)
+                ignored_params.update(module.parameters())
 
         _validate_frozen_params(model, modules_to_wrap, ignored_params, use_orig_params)
 
diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index 4b46c568aff..9f1819570a6 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -1375,8 +1375,7 @@ def forward(self, getitem, const):
         cond_gm = backend.graphs[0]
         name_set = set()
-        for name, _ in cond_gm.named_modules():
-            name_set.add(name)
+        name_set.update(name for name, _ in cond_gm.named_modules())
         self.assertEqual(
             name_set,
             {
@@ -1735,8 +1734,7 @@ def forward(self):
         self.assertEqual(result, x + y + x)
         wrap_gm = backend.graphs[0]
         names = set()
-        for mod_name, _ in wrap_gm.named_modules():
-            names.add(mod_name)
+        names.update(mod_name for mod_name, _ in wrap_gm.named_modules())
         self.assertEqual(
             names,
             {
diff --git a/test/functorch/discover_coverage.py b/test/functorch/discover_coverage.py
index f6e08581485..dd7e8b6c9ae 100644
--- a/test/functorch/discover_coverage.py
+++ b/test/functorch/discover_coverage.py
@@ -365,8 +365,7 @@ def get_all_tested_ops():
     result = set({})
     for op in get_covered_ops(overridable_outplace_we_care_about).values():
         opinfos = op_to_opinfo[op]
-        for opinfo in opinfos:
-            result.add(opinfo.name)
+        result.update(opinfo.name for opinfo in opinfos)
     return result
 
 
diff --git a/test/package/test_digraph.py b/test/package/test_digraph.py
index 90dc11f3a10..9868466b64a 100644
--- a/test/package/test_digraph.py
+++ b/test/package/test_digraph.py
@@ -79,8 +79,7 @@ class TestDiGraph(PackageTestCase):
         g.add_node(3)
 
         nodes = set()
-        for n in g:
-            nodes.add(n)
+        nodes.update(g)
 
         self.assertEqual(nodes, {1, 2, 3})
 
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index a921f59dad9..cd8600a573f 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -1617,8 +1617,7 @@ except RuntimeError as e:
         dataset = SynchronizedSeedDataset(num_workers, batch_size, num_workers)
         dataloader = self._get_data_loader(dataset, batch_size=batch_size, num_workers=num_workers)
         seeds = set()
-        for batch in dataloader:
-            seeds.add(batch[0])
+        seeds.update(batch[0] for batch in dataloader)
         self.assertEqual(len(seeds), num_workers)
 
     def test_worker_seed_reproducibility(self):
diff --git a/test/test_torch.py b/test/test_torch.py
index 0dbcc0fa793..735a4f447ae 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -9523,8 +9523,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
         device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'}
         device_hash_set = set()
-        for device in device_set:
-            device_hash_set.add(hash(torch.device(device)))
+        device_hash_set.update(hash(torch.device(device)) for device in device_set)
         self.assertEqual(len(device_set), len(device_hash_set))
 
     def get_expected_device_repr(device):
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index 72a880ab48a..ba255249159 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -3233,17 +3233,19 @@ if torch.distributed.is_available():
 
 @functools.lru_cache(None)
 def get_legacy_mod_inlinelist():
-    inlinelist = set()
-    for m in LEGACY_MOD_INLINELIST:
-        inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
+    inlinelist = {
+        _module_dir(torch) + m[len("torch.") :].replace(".", "/")
+        for m in LEGACY_MOD_INLINELIST
+    }
     return inlinelist
 
 
 @functools.lru_cache(None)
 def get_mod_inlinelist():
-    inlinelist = set()
-    for m in MOD_INLINELIST:
-        inlinelist.add(_module_dir(torch) + m[len("torch.") :].replace(".", "/"))
+    inlinelist = {
+        _module_dir(torch) + m[len("torch.") :].replace(".", "/")
+        for m in MOD_INLINELIST
+    }
     return inlinelist
 
 
diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py
index fd7fba3e8f5..873441a971a 100644
--- a/torch/_functorch/partitioners.py
+++ b/torch/_functorch/partitioners.py
@@ -744,8 +744,7 @@ def min_cut_rematerialization_partition(
         if node.op == "placeholder" and "tangents" in node.target:
             required_bw_nodes.add(node)
         if node in required_bw_nodes:
-            for user in node.users:
-                required_bw_nodes.add(user)
+            required_bw_nodes.update(node.users)
 
     primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
     fwd_seed_offset_inputs = list(
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index e8ca0dd18bf..e4b30f0ba19 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -3623,8 +3623,7 @@ class CppScheduling(BaseScheduling):
                 if var_ranges is None:
                     var_ranges = v
                 assert var_ranges == v, (var_ranges, v, node.snodes)
-                for expr in exprs:
-                    indexing_exprs.add(expr)
+                indexing_exprs.update(exprs)
             return var_ranges, list(indexing_exprs)
         else:
             assert isinstance(node, SchedulerNode)
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index e42176a769e..97e16831208 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -635,8 +635,7 @@ class GraphLowering(torch.fx.Interpreter):
         #   - sebotnet33ts_256
         for n in self.module.graph.nodes:
             if n in output_set:
-                for child in n.users:
-                    output_set.add(child)
+                output_set.update(n.users)
 
         return output_set
 
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 57d48f1f4c7..7fd89ab3bfc 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -89,8 +89,9 @@ def add_needs_realized_inputs(fn):
         return [add_needs_realized_inputs(x) for x in fn]
     needs_realized_inputs.add(fn)
     if isinstance(fn, torch._ops.OpOverloadPacket):
-        for overload in fn.overloads():
-            needs_realized_inputs.add(getattr(fn, overload))
+        needs_realized_inputs.update(
+            getattr(fn, overload) for overload in fn.overloads()
+        )
 
 
 def add_layout_constraint(fn, constraint):
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 30aa54843d1..32f734ba8b3 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -2292,9 +2292,7 @@ class Scheduler:
         Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
         """
-        future_used_buffers = set()
-        for node_name in V.graph.get_output_names():
-            future_used_buffers.add(node_name)
+        future_used_buffers = set(V.graph.get_output_names())
 
         for node in reversed(self.nodes):
             node.set_last_usage(future_used_buffers, self.mutation_real_name)
diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py
index f21d22651eb..bd36d6b652e 100644
--- a/torch/_library/custom_ops.py
+++ b/torch/_library/custom_ops.py
@@ -223,9 +223,10 @@ class CustomOpDef:
         def backend_impl(*args, **kwargs):
             # Checks the assumption that outputs cannot alias
             # inputs or other outputs.
-            storages = set()
-            for tensor in iter_tensors(args, kwargs):
-                storages.add(id(tensor.untyped_storage()))
+            storages = {
+                id(tensor.untyped_storage())
+                for tensor in iter_tensors(args, kwargs)
+            }
 
             result = self._backend_fns[device_type](*args, **kwargs)
 
diff --git a/torch/ao/ns/fx/n_shadows_utils.py b/torch/ao/ns/fx/n_shadows_utils.py
index b7eddf93e2a..1fd6f069ac8 100644
--- a/torch/ao/ns/fx/n_shadows_utils.py
+++ b/torch/ao/ns/fx/n_shadows_utils.py
@@ -742,8 +742,7 @@ def create_add_loggers_graph(
         insert_submodule_copy = False
         if maybe_subgraph is not None:
             first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
-            for node_to_skip in maybe_subgraph:
-                nodes_to_skip.add(node_to_skip)
+            nodes_to_skip.update(maybe_subgraph)
             qconfig = node_name_to_qconfig[first_node.name]
             if qconfig is not None:
                 insert_submodule_copy = True
@@ -873,8 +872,7 @@ def create_add_loggers_graph(
             maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
             if maybe_subgraph is not None:
                 first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
-                for node_to_skip in maybe_subgraph:
-                    nodes_to_skip.add(node_to_skip)
+                nodes_to_skip.update(maybe_subgraph)
             else:
                 first_node, last_node = n, n
 
diff --git a/torch/ao/quantization/quantizer/embedding_quantizer.py b/torch/ao/quantization/quantizer/embedding_quantizer.py
index 8ffd2002e58..81306943264 100644
--- a/torch/ao/quantization/quantizer/embedding_quantizer.py
+++ b/torch/ao/quantization/quantizer/embedding_quantizer.py
@@ -45,9 +45,9 @@ class EmbeddingQuantizer(Quantizer):
 
     @classmethod
     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.get_supported_operators():
-            op_configs.add(spec)
+        op_configs: Set[QuantizationConfig] = {
+            spec for spec, _ in cls.get_supported_operators()
+        }
         return list(op_configs)
 
     @classmethod
diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
index 8889cf2df0c..269b0128c66 100644
--- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
+++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py
@@ -286,9 +286,9 @@ class X86InductorQuantizer(Quantizer):
 
     @classmethod
     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
+        op_configs: Set[QuantizationConfig] = {
+            spec for spec, _ in cls.supported_config_and_operators
+        }
         return list(op_configs)
 
     @classmethod
diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
index 1f7dad387ff..b66cfdf37a7 100644
--- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py
+++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py
@@ -305,9 +305,9 @@ class XNNPACKQuantizer(Quantizer):
 
     @classmethod
     def get_supported_quantization_configs(cls) -> List[QuantizationConfig]:
-        op_configs: Set[QuantizationConfig] = set({})
-        for spec, _ in cls.supported_config_and_operators:
-            op_configs.add(spec)
+        op_configs: Set[QuantizationConfig] = {
+            spec for spec, _ in cls.supported_config_and_operators
+        }
         return list(op_configs)
 
     @classmethod
diff --git a/torch/distributed/_tensor/ops/basic_strategy.py b/torch/distributed/_tensor/ops/basic_strategy.py
index 80055281236..6c2d87f470d 100644
--- a/torch/distributed/_tensor/ops/basic_strategy.py
+++ b/torch/distributed/_tensor/ops/basic_strategy.py
@@ -1,7 +1,7 @@
 import itertools
 from dataclasses import dataclass
-from typing import List, Tuple
+from typing import List, Set, Tuple
 
 from torch.distributed._tensor.op_schema import OpStrategy, PlacementStrategy
 from torch.distributed._tensor.placement_types import (
@@ -44,10 +44,9 @@ class EinsumDims:
         Parse the dims and extract the contracting, batch, and free dimensions
         for the left and right hand sides.
         """
-        dim_char_set = set()
+        dim_char_set: Set[str] = set()
         for input_dim in input_dims:
-            for input_char in list(input_dim):
-                dim_char_set.add(input_char)
+            dim_char_set.update(input_dim)
 
         # get a determinisitc order of all dim chars
         all_dim_chars = sorted(dim_char_set)
diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py
index 84659586ea8..a8f8216057a 100644
--- a/torch/distributed/checkpoint/state_dict.py
+++ b/torch/distributed/checkpoint/state_dict.py
@@ -218,7 +218,7 @@ def _verify_options(
         fqn_param_mapping[fqn] = param
         all_fqns.add(fqn)
 
-    submodule_prefixes = set()
+    submodule_prefixes: Set[str] = set()
     if submodules:
         submodules = set(submodules)
         for name, module in model.named_modules():
@@ -226,8 +226,7 @@
                 continue
             fqns = _get_fqns(model, name)
             assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
-            for fqn in fqns:
-                submodule_prefixes.add(f"{fqn}.")
+            submodule_prefixes.update(f"{fqn}." for fqn in fqns)
 
     fsdp_modules = FSDP.fsdp_modules(model)
     state_dict_config: StateDictConfig
diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py
index 9569a0d01b5..3e797638abd 100644
--- a/torch/fx/graph_module.py
+++ b/torch/fx/graph_module.py
@@ -112,9 +112,7 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
 
 
 def _format_import_block(globals: Dict[str, Any], importer: Importer):
-    import_strs: Set[str] = set()
-    for name, obj in globals.items():
-        import_strs.add(_format_import_statement(name, obj, importer))
+    import_strs: Set[str] = {_format_import_statement(name, obj, importer) for name, obj in globals.items()}
     # Sort the imports so we have a stable import block that allows us to
     # hash the graph module and get a consistent key for use in a cache.
     return "\n".join(sorted(import_strs))
diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py
index b4972720a05..d0bb4b55a40 100644
--- a/torch/fx/subgraph_rewriter.py
+++ b/torch/fx/subgraph_rewriter.py
@@ -294,8 +294,7 @@ def _replace_pattern(
         # Copy the replacement graph over
         user_nodes: Set[Node] = set()
         for n in match.returning_nodes:
-            for user in n.users:
-                user_nodes.add(user)
+            user_nodes.update(n.users)
         assert user_nodes, "The returning_nodes should have at least one user node"
 
         if len(user_nodes) == 1:
diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py
index f091dd47d03..be3edc50655 100644
--- a/torch/profiler/_memory_profiler.py
+++ b/torch/profiler/_memory_profiler.py
@@ -930,8 +930,9 @@ class MemoryProfile:
                 self._is_gradient(*i) or i in used_for_gradient
                 for i in node.outputs.items()
             ):
-                for key, (_, version) in node.inputs.items():
-                    used_for_gradient.add((key, version))
+                used_for_gradient.update(
+                    (key, version) for key, (_, version) in node.inputs.items()
+                )
             candidate_parameters.intersection_update(used_for_gradient)
 
         # and depends on a gradient.
diff --git a/torch/utils/data/datapipes/_hook_iterator.py b/torch/utils/data/datapipes/_hook_iterator.py
index 7463cc55d27..49e17438d60 100644
--- a/torch/utils/data/datapipes/_hook_iterator.py
+++ b/torch/utils/data/datapipes/_hook_iterator.py
@@ -34,9 +34,7 @@ def _strip_datapipe_from_name(name: str) -> str:
 def _generate_input_args_string(obj):
     """Generate a string for the input arguments of an object."""
     signature = inspect.signature(obj.__class__)
-    input_param_names = set()
-    for param_name in signature.parameters.keys():
-        input_param_names.add(param_name)
+    input_param_names = set(signature.parameters.keys())
     result = []
     for name, value in inspect.getmembers(obj):
         if name in input_param_names:
diff --git a/torch/utils/tensorboard/_caffe2_graph.py b/torch/utils/tensorboard/_caffe2_graph.py
index 53674602605..cd2d371204c 100644
--- a/torch/utils/tensorboard/_caffe2_graph.py
+++ b/torch/utils/tensorboard/_caffe2_graph.py
@@ -578,10 +578,8 @@ def _compute_in_out(ops):
     out_blobs = set()
 
     for op in ops:
-        for input_blob in op.input:
-            in_blobs.add(input_blob)
-        for output_blob in op.output:
-            out_blobs.add(output_blob)
+        in_blobs.update(op.input)
+        out_blobs.update(op.output)
 
     input_blobs = list(in_blobs.difference(out_blobs))
     output_blobs = list(out_blobs.difference(in_blobs))
@@ -700,8 +698,7 @@ def _operators_to_graph_def(
             else [_operator_to_node(shapes, op)]
         )  # .extend() expects an iterable
         current_graph.node.extend(nodes_from_op)
-        for input_blob in op.input:
-            blobs.add(input_blob)
+        blobs.update(op.input)
         for i, output_blob in enumerate(op.output):
             blobs.add(output_blob)
             producing_ops.setdefault(output_blob, []).append((op, i))
diff --git a/torchgen/gen.py b/torchgen/gen.py
index 2549fd175c0..dee23957e3e 100644
--- a/torchgen/gen.py
+++ b/torchgen/gen.py
@@ -2125,7 +2125,7 @@ def gen_headers(
     )
 
     def gen_aten_interned_strings() -> Dict[str, str]:
-        attrs = set()  # All function argument names
+        attrs: Set[str] = set()  # All function argument names
         names = set()  # All ATen function names
         for func in native_functions:
             names.add(str(func.func.name.name))
@@ -2133,8 +2133,7 @@
            # symbol without the underscore
             names.add(func.func.name.name.base)
 
-            for arg in func.func.schema_order_arguments():
-                attrs.add(arg.name)
+            attrs.update(arg.name for arg in func.func.schema_order_arguments())
 
         # These are keywords in C++, so aren't valid symbol names
         # https://en.cppreference.com/w/cpp/language/operator_alternative