diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index bad399c19dc..9a22e7f209e 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -4,6 +4,7 @@ import functools import itertools import logging import math +import operator import os import pprint import textwrap @@ -2248,9 +2249,13 @@ class Scheduler: self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)}) return self.origin_to_index[n] - origins = [(get_order(e), e) for n in node.get_nodes() for e in n.node.origins] + # Use a dict to have ordering + origins = { + (get_order(e), e): None for n in node.get_nodes() for e in n.node.origins + } + origins = list(origins.keys()) if origins: - _, last = max(origins) + _, last = max(origins, key=operator.itemgetter(0)) V.graph.wrapper_code.enter_context(last) @dynamo_timed