diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 4b8a9abec6a..1cef108a3eb 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -1749,8 +1749,7 @@ def is_mutation_op(node: torch.fx.Node) -> bool: if _mutation_op_re.search(node.target.__name__): return True elif node.op == "call_method": - assert isinstance(node.target, str) - if _mutation_op_re.search(node.target): + if _mutation_op_re.search(cast(str, node.target)): return True return node.kwargs.get("out") is not None @@ -1834,8 +1833,7 @@ class PatternMatcherPass: if has_call_module: nodes.append(graph.find_nodes(op="call_module", sort=False)) pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher" - assert isinstance(gm, torch.fx.GraphModule) - with GraphTransformObserver(gm, pass_name): + with GraphTransformObserver(cast(torch.fx.GraphModule, gm), pass_name): for node in sorted(itertools.chain.from_iterable(nodes), reverse=True): target = extract_target(node) if node.op == "call_module": @@ -2160,6 +2158,5 @@ def extract_target(node: torch.fx.Node) -> torch.fx.node.Target: as a function. """ if node.op == "call_module": - assert isinstance(node.target, str) - return _get_attr(node.graph.owning_module, node.target).__class__ + return _get_attr(node.graph.owning_module, cast(str, node.target)).__class__ return node.target