[BE][Ez]: FURB148 - remove useless enumerate calls (#145619)

Remove useless enumerate calls

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145619
Approved by: https://github.com/drisspg
This commit is contained in:
Aaron Gokaslan 2025-01-24 23:37:12 +00:00 committed by PyTorch MergeBot
parent 0741963e01
commit f3304571fc
15 changed files with 20 additions and 20 deletions

View file

@ -182,7 +182,7 @@ class TimeoutTest(TestCase):
threads.append(t)
t.start()
for _, thread in enumerate(threads):
for thread in threads:
thread.join()
# we expect the world_size-1 threads to have failed

View file

@ -48,7 +48,7 @@ class TestFakePG(TestCase):
input_tensor = torch.ones(3, 3) * dist.get_rank()
output_tensors = [torch.empty_like(input_tensor) for _ in range(2)]
dist.all_gather(output_tensors, input_tensor)
for _, out_tensor in enumerate(output_tensors):
for out_tensor in output_tensors:
self.assertEqual(tuple(out_tensor.shape), (3, 3))
def test_reduce_scatter(self):

View file

@ -145,7 +145,7 @@ class TestAliasAnalysis(JitTestCase):
def forward(self, x):
return x + 2
for _, fname in enumerate(fnames):
for fname in fnames:
mod = torch.jit.script(MyModuleCUTest())
torch.jit.save(mod, fname)
loaded_mod = torch.jit.load(fname)

View file

@ -845,7 +845,7 @@ class TestProfiler(TestCase):
super().__init__(*args, **kwargs)
def train():
for _, data in enumerate(dataloader):
for data in dataloader:
x, y = data[0], data[1]
y_pred = model(x)
loss = criterion(y_pred, y)

View file

@ -2499,7 +2499,7 @@ class TestTEFuser(JitTestCase):
for i, func in enumerate(funcs):
num_args = i + 1
for j, gen in enumerate(gen_tensor):
for gen in gen_tensor:
inps = (gen(n), gen(n), gen(n))
func_s = torch.jit.trace(func, inps, check_trace=False)
torch._C._jit_pass_erase_shape_information(func_s.graph)

View file

@ -1686,7 +1686,7 @@ class TestCommon(TestCase):
def test_meta_consistency_out_dtype_mismatch(self, device, dtype, op):
samples = op.sample_inputs(device, dtype)
for i, sample in enumerate(samples):
for sample in samples:
input, args, kwargs = (sample.input, sample.args, sample.kwargs)
try:
@ -2763,7 +2763,7 @@ class TestFakeTensor(TestCase):
def _test_fake_crossref_helper(self, device, dtype, op, context):
samples = op.sample_inputs(device, dtype, requires_grad=True)
for iter, sample in enumerate(samples):
for sample in samples:
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs

View file

@ -1358,7 +1358,7 @@ class TestFFT(TestCase):
'onesided': True,
},
]
for i, pattern in enumerate(patterns):
for pattern in patterns:
_test_istft_is_inverse_of_stft(pattern)
@onlyNativeDeviceTypes
@ -1425,7 +1425,7 @@ class TestFFT(TestCase):
'onesided': True,
},
]
for i, pattern in enumerate(patterns):
for pattern in patterns:
_test_istft_is_inverse_of_stft_with_padding(pattern)
@onlyNativeDeviceTypes

View file

@ -116,7 +116,7 @@ class TestTypes(TestCase):
)
def test_type_create(self):
for _, atype in enumerate(types):
for atype in types:
a = np.array([1, 2, 3], atype)
b = atype([1, 2, 3])
assert_equal(a, b)

View file

@ -190,7 +190,7 @@ def preprocess(
# Indicates whether this is the first line inside Python
# code block (i.e. for, while, if, elif, else)
python_block_start = True
for i, input_line in enumerate(input_lines):
for input_line in input_lines:
if input_line == "":
blank_lines += 1
continue

View file

@ -63,7 +63,7 @@ class TestSetLinter(LinterTestCase):
{0: 25, 2: 24, 3: 23, 6: 8, 12: 14, 18: 22},
),
)
for i, (s, expected) in enumerate(TESTS):
for s, expected in TESTS:
pl = python_lines(s)
if s:
actual = pl.token_lines[0].bracket_pairs
@ -83,7 +83,7 @@ class TestSetLinter(LinterTestCase):
("{1, 2}", 1),
("{One({'a': 1}), Two([{}, {2}, {1, 2}])}", 3),
)
for i, (s, expected) in enumerate(TESTS):
for s, expected in TESTS:
pl = python_lines(s)
actual = pl.token_lines and pl.token_lines[0].braced_sets
self.assertEqual(len(actual), expected)

View file

@ -457,7 +457,7 @@ def save_tensors_and_symints_for_backward(ctx, args):
), args
partitioned_args: list[Any] = [[], []]
pos = []
for i, arg in enumerate(args):
for arg in args:
idx = 0 if isinstance(arg, torch.Tensor) else 1
partitioned_args[idx].append(arg)
pos.append(idx)

View file

@ -1416,7 +1416,7 @@ class SIMDScheduling(BaseScheduling):
kernel.finalize_indexing(all_indexing.keys())
# Second pass to do codegen
for i, node in enumerate(node_schedule):
for node in node_schedule:
if node is DisableReduction:
stack.enter_context(kernel.disable_reduction())
elif node is EnableReduction:

View file

@ -893,15 +893,15 @@ def visualize_results(
"""
html_content += "<tr><th>\\</th>"
for i, col_name in enumerate(input_list):
for col_name in input_list:
col = "<br>".join(col_name)
html_content += f"<th>{col}</th>"
html_content += "</tr></thead><tbody>"
# Add table rows
for i, row_name in enumerate(input_list):
for row_name in input_list:
html_content += f"<tr><th>{row_name}</th>"
for j, col_name in enumerate(input_list):
for col_name in input_list:
# Determine the status class for the cell
status_enum = results.lookup((row_name, col_name))
status_class = ""

View file

@ -411,7 +411,7 @@ def topological_sort_lpmf(
# compute the amount of memory that is allocated when a node is scheduled
# and the amount of memory that can be freed when a node is scheduled
for i, node in enumerate(nodes):
for node in nodes:
# 1. if a buffer read by this node is last used by this node
for buf in node.mpi_node.pred_buffers:
if buf_info[buf]["outdegree"] == 1:

View file

@ -634,7 +634,7 @@ class UnflattenedModule(torch.nn.Module):
for orig_fqn, indexed_call_modules in called_modules.items():
call_modules = [mod for _, mod in sorted(indexed_call_modules)]
if len(call_modules) > 1:
for i, call_module in enumerate(call_modules):
for i in range(len(call_modules)):
fqn = _call_name(orig_fqn, i + 1)
if fqn not in redirected_call_indices:
*prefix, name = fqn.split(".")