mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor][scheduler] Use set for origin (#119861)
xref - https://github.com/pytorch/pytorch/issues/119440 This avoids node > node comparison if the origin order is same in the origins tuple. However, I am unable to come up with a test case where this could happen. Pull Request resolved: https://github.com/pytorch/pytorch/pull/119861 Approved by: https://github.com/Skylion007, https://github.com/eellison
This commit is contained in:
parent
29235c7063
commit
6b04251b87
1 changed files with 7 additions and 2 deletions
|
|
@ -4,6 +4,7 @@ import functools
|
|||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import operator
|
||||
import os
|
||||
import pprint
|
||||
import textwrap
|
||||
|
|
@ -2248,9 +2249,13 @@ class Scheduler:
|
|||
self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)})
|
||||
return self.origin_to_index[n]
|
||||
|
||||
origins = [(get_order(e), e) for n in node.get_nodes() for e in n.node.origins]
|
||||
# Use a dict to have ordering
|
||||
origins = {
|
||||
(get_order(e), e): None for n in node.get_nodes() for e in n.node.origins
|
||||
}
|
||||
origins = list(origins.keys())
|
||||
if origins:
|
||||
_, last = max(origins)
|
||||
_, last = max(origins, key=operator.itemgetter(0))
|
||||
V.graph.wrapper_code.enter_context(last)
|
||||
|
||||
@dynamo_timed
|
||||
|
|
|
|||
Loading…
Reference in a new issue