diff --git a/test/test_fx.py b/test/test_fx.py index dcb50d5cd38..586897adcd4 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -2768,7 +2768,7 @@ class TestFX(JitTestCase): return self.other(x) traced = symbolic_trace(ReturnTypeModule()) - self.assertIn("-> list[str]", traced._code) + self.assertIn("-> typing_List[str]", traced._code) scripted = torch.jit.script(traced) self.assertIn("-> List[str]", scripted.code) @@ -3567,8 +3567,8 @@ class TestFX(JitTestCase): traced(x, y) - FileCheck().check("tuple[()]") \ - .check("tuple[str,tuple[()]]") \ + FileCheck().check("typing_Tuple[()]") \ + .check("typing_Tuple[str,typing_Tuple[()]]") \ .run(traced.code) scripted = torch.jit.script(traced) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 5698d76d66c..89f97b2419c 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -9,6 +9,7 @@ import keyword import math import os import re +import typing import warnings from collections import defaultdict from collections.abc import Iterable @@ -33,12 +34,13 @@ if TYPE_CHECKING: # Mapping of builtins to their `typing` equivalent. +# (PEP585: See D68459095 test plan) _origin_type_map = { - list: list, - dict: dict, - set: set, - frozenset: frozenset, - tuple: tuple, + list: typing.List, + dict: typing.Dict, + set: typing.Set, + frozenset: typing.FrozenSet, + tuple: typing.Tuple, } _legal_ops = dict.fromkeys(