mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update
[ghstack-poisoned]
This commit is contained in:
parent
8d806ed210
commit
e0a9d859ca
1 changed files with 0 additions and 3 deletions
|
|
@ -1123,7 +1123,6 @@ class ReplacementPatternEntry(PatternEntry):
|
|||
queue.extend(arg.all_input_nodes)
|
||||
|
||||
with graph.inserting_before(last_node):
|
||||
# assert isinstance(replacement_graph, torch.fx.GraphModule)
|
||||
replacement_module = cast(torch.fx.GraphModule, replacement_graph)
|
||||
replacement = Replacer(replacement_module).run(*args)
|
||||
if isinstance(replacement, torch.fx.Node):
|
||||
|
|
@ -1776,13 +1775,11 @@ def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
|
|||
|
||||
|
||||
def should_compute_mutation_region_ids(graph: torch.fx.Graph) -> bool:
|
||||
# assert isinstance(graph.nodes, Iterable)
|
||||
return "mutation_region_id" not in next(iter(graph.nodes)).meta
|
||||
|
||||
|
||||
def compute_mutation_region_ids(graph: torch.fx.Graph) -> None:
|
||||
mutation_region_id = 0
|
||||
# assert isinstance(graph.nodes, Iterable)
|
||||
for nd in graph.nodes:
|
||||
if is_mutation_op(nd):
|
||||
mutation_region_id += 1
|
||||
|
|
|
|||
Loading…
Reference in a new issue