mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update
[ghstack-poisoned]
This commit is contained in:
parent
e0a9d859ca
commit
0c5508bbf0
1 changed files with 3 additions and 6 deletions
|
|
@ -1749,8 +1749,7 @@ def is_mutation_op(node: torch.fx.Node) -> bool:
|
|||
if _mutation_op_re.search(node.target.__name__):
|
||||
return True
|
||||
elif node.op == "call_method":
|
||||
assert isinstance(node.target, str)
|
||||
if _mutation_op_re.search(node.target):
|
||||
if _mutation_op_re.search(cast(str, node.target)):
|
||||
return True
|
||||
return node.kwargs.get("out") is not None
|
||||
|
||||
|
|
@ -1834,8 +1833,7 @@ class PatternMatcherPass:
|
|||
if has_call_module:
|
||||
nodes.append(graph.find_nodes(op="call_module", sort=False))
|
||||
pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
|
||||
assert isinstance(gm, torch.fx.GraphModule)
|
||||
with GraphTransformObserver(gm, pass_name):
|
||||
with GraphTransformObserver(cast(torch.fx.GraphModule, gm), pass_name):
|
||||
for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
|
||||
target = extract_target(node)
|
||||
if node.op == "call_module":
|
||||
|
|
@ -2160,6 +2158,5 @@ def extract_target(node: torch.fx.Node) -> torch.fx.node.Target:
|
|||
as a function.
|
||||
"""
|
||||
if node.op == "call_module":
|
||||
assert isinstance(node.target, str)
|
||||
return _get_attr(node.graph.owning_module, node.target).__class__
|
||||
return _get_attr(node.graph.owning_module, cast(str, node.target)).__class__
|
||||
return node.target
|
||||
|
|
|
|||
Loading…
Reference in a new issue