[ghstack-poisoned]
This commit is contained in:
Tom Ritchford 2025-02-08 13:59:04 +00:00
parent e0a9d859ca
commit 0c5508bbf0

View file

@ -1749,8 +1749,7 @@ def is_mutation_op(node: torch.fx.Node) -> bool:
if _mutation_op_re.search(node.target.__name__):
return True
elif node.op == "call_method":
assert isinstance(node.target, str)
if _mutation_op_re.search(node.target):
if _mutation_op_re.search(cast(str, node.target)):
return True
return node.kwargs.get("out") is not None
@ -1834,8 +1833,7 @@ class PatternMatcherPass:
if has_call_module:
nodes.append(graph.find_nodes(op="call_module", sort=False))
pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
assert isinstance(gm, torch.fx.GraphModule)
with GraphTransformObserver(gm, pass_name):
with GraphTransformObserver(cast(torch.fx.GraphModule, gm), pass_name):
for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
target = extract_target(node)
if node.op == "call_module":
@ -2160,6 +2158,5 @@ def extract_target(node: torch.fx.Node) -> torch.fx.node.Target:
as a function.
"""
if node.op == "call_module":
assert isinstance(node.target, str)
return _get_attr(node.graph.owning_module, node.target).__class__
return _get_attr(node.graph.owning_module, cast(str, node.target)).__class__
return node.target