From 00d962631c2b17275505a92cad168bc8a9ffe74d Mon Sep 17 00:00:00 2001 From: isdanni Date: Fri, 13 Oct 2023 22:20:00 +0000 Subject: [PATCH] [BE] Enable Ruff's Flake8 PYI045 (#111184) Enable [iter-method-return-iterable (PYI045)](https://docs.astral.sh/ruff/rules/iter-method-return-iterable/#iter-method-return-iterable-pyi045) Link: #110950 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111184 Approved by: https://github.com/Skylion007 --- pyproject.toml | 1 - torch/distributed/pipeline/sync/pipe.py | 4 ++-- torch/fx/proxy.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d28df012a24..bd9a330f8c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ ignore = [ "PYI024", "PYI036", "PYI041", - "PYI045", "PYI056", "SIM102", "SIM103", "SIM112", # flake8-simplify code styles "SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason diff --git a/torch/distributed/pipeline/sync/pipe.py b/torch/distributed/pipeline/sync/pipe.py index 65063e9b1c8..5e61341d9ad 100644 --- a/torch/distributed/pipeline/sync/pipe.py +++ b/torch/distributed/pipeline/sync/pipe.py @@ -6,7 +6,7 @@ # LICENSE file in the root directory of this source tree. """The Pipe interface.""" from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union, Sequence, Tuple, cast +from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Union, Sequence, Tuple, cast import torch from torch import Tensor, nn @@ -379,7 +379,7 @@ class Pipe(Module): raise IndexError - def __iter__(self) -> Iterable[nn.Module]: + def __iter__(self) -> Iterator[nn.Module]: """Iterates over children of the underlying sequential module.""" for partition in self.partitions: yield from partition diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 51f7f8654fc..fd62b1b2baa 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -12,7 +12,7 @@ from dataclasses import is_dataclass, fields from .graph import magic_methods, reflectable_magic_methods, Graph -from typing import Tuple, Dict, OrderedDict, Optional, Iterable, Any, Iterator, Callable +from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable from .node import Target, Node, Argument, base_types, map_aggregate from ._compatibility import compatibility from .operator_schemas import check_for_mutable_operation @@ -392,7 +392,7 @@ class Proxy: def __call__(self, *args, **kwargs) -> 'Proxy': return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) - def __iter__(self) -> Iterable['Proxy']: + def __iter__(self) -> Iterator['Proxy']: frame = inspect.currentframe() assert frame is not None calling_frame = frame.f_back