[ghstack-poisoned]
This commit is contained in:
Tom Ritchford 2025-02-08 13:54:25 +00:00
parent 8d806ed210
commit e0a9d859ca

View file

@ -1123,7 +1123,6 @@ class ReplacementPatternEntry(PatternEntry):
queue.extend(arg.all_input_nodes)
with graph.inserting_before(last_node):
# assert isinstance(replacement_graph, torch.fx.GraphModule)
replacement_module = cast(torch.fx.GraphModule, replacement_graph)
replacement = Replacer(replacement_module).run(*args)
if isinstance(replacement, torch.fx.Node):
@ -1776,13 +1775,11 @@ def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
def should_compute_mutation_region_ids(graph: torch.fx.Graph) -> bool:
# assert isinstance(graph.nodes, Iterable)
return "mutation_region_id" not in next(iter(graph.nodes)).meta
def compute_mutation_region_ids(graph: torch.fx.Graph) -> None:
mutation_region_id = 0
# assert isinstance(graph.nodes, Iterable)
for nd in graph.nodes:
if is_mutation_op(nd):
mutation_region_id += 1