From 829b49b867b12f6c5f55e423ee22152ef2b7e172 Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Tue, 7 Dec 2021 14:15:23 -0800 Subject: [PATCH] Output UnionType str rep with () instead of [] (#69502) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69502 Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D32902781 Pulled By: tugsbayasgalan fbshipit-source-id: 67a73b209575437477cdbd3eb8f685019709e99c --- aten/src/ATen/core/jit_type.h | 4 +++- aten/src/ATen/core/type.cpp | 11 +++++++---- test/jit/test_union.py | 29 ++++++++++++++++++++++------- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 56856de49ca..212138351f0 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -147,7 +147,9 @@ struct TORCH_API UnionType : public Type { protected: explicit UnionType(std::vector types, TypeKind kind=TypeKind::UnionType); std::string annotation_str_impl(TypePrinter printer = nullptr) const override; - std::string unionStr(TypePrinter printer = nullptr, bool is_annotation_str = false) const; + std::string unionStr( + TypePrinter printer = nullptr, + bool is_annotation_str = false) const; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) bool has_free_variables_; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 20a531792f9..654f645767b 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -1099,8 +1099,8 @@ bool UnionType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const { }); } - -std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) const { +std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) + const { std::stringstream ss; bool can_hold_numbertype = this->canHoldType(*NumberType::get()); @@ -1116,7 +1116,10 @@ std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) con return false; }; - ss << "Union["; + std::string open_delimeter = is_annotation_str ? "[" : "("; + std::string close_delimeter = is_annotation_str ? "]" : ")"; + + ss << "Union" + open_delimeter; bool printed = false; for (size_t i = 0; i < types_.size(); ++i) { if (!can_hold_numbertype || !is_numbertype(types_[i])) { @@ -1141,7 +1144,7 @@ std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) con ss << NumberType::get()->str(); } } - ss << "]"; + ss << close_delimeter; return ss.str(); } diff --git a/test/jit/test_union.py b/test/jit/test_union.py index 24e7419a642..c5b9e59bcb9 100644 --- a/test/jit/test_union.py +++ b/test/jit/test_union.py @@ -33,6 +33,21 @@ class TestUnion(JitTestCase): equivalent functions to emulate `checkScript`. """ + def test_check_union_annotation(self): + def test_func(a: Union[int, float], b: Optional[int]): + return 0 + + scripted_func = torch.jit.script(test_func) + graph_rep = str(scripted_func.graph) + code_rep = str(scripted_func.code) + # TS graph IR for Union should be annotated as Union() + FileCheck().check("Union(").check("int?").run(graph_rep) + # Serialized code for Union should be annotated as Union[] + FileCheck().check("Union[").check("Optional[int]").run(code_rep) + self.checkScript(test_func, (5, 6)) + # this shouldn't error out + torch._C.parse_ir(str(scripted_func.graph)) + def test_union_with_scalar_values(self): def fn(x: Union[int, float]) -> str: return "foo" @@ -210,7 +225,7 @@ class TestUnion(JitTestCase): s = fn.graph - FileCheck().check("x : Union[float, int, str]") \ + FileCheck().check("x : Union(float, int, str)") \ .run(s) def test_unions_of_a_single_argument_vanish(self): @@ -230,7 +245,7 @@ class TestUnion(JitTestCase): s = fn.graph - FileCheck().check("x : Union[int, str]") \ + FileCheck().check("x : Union(int, str)") \ .run(s) def test_union_redundant_arguments_are_skipped_optional(self): @@ -240,7 +255,7 @@ class TestUnion(JitTestCase): s = fn.graph - FileCheck().check("x : Union[float, int, NoneType]") \ + FileCheck().check("x : Union(float, int, NoneType)") \ .run(s) def test_union_redundant_arguments_are_skipped_subtyping(self): @@ -250,7 +265,7 @@ class TestUnion(JitTestCase): s = fn.graph - FileCheck().check("x : Union[(int?, int), str]") \ + FileCheck().check("x : Union((int?, int), str)") \ .run(s) def test_union_redundant_arguments_are_skipped_container(self): @@ -260,7 +275,7 @@ class TestUnion(JitTestCase): s = fn.graph - FileCheck().check("x : Union[float[], str[]]") \ + FileCheck().check("x : Union(float[], str[])") \ .run(s) def test_union_argument_order_is_ignored(self): @@ -273,7 +288,7 @@ class TestUnion(JitTestCase): return "foo" for s in (fn1.graph, fn2.graph): - FileCheck().check("x : Union[int, str]") \ + FileCheck().check("x : Union(int, str)") \ .run(s) def test_union_argument_order_is_ignored_container(self): @@ -286,7 +301,7 @@ class TestUnion(JitTestCase): return "foo" for s in (fn1.graph, fn2.graph): - FileCheck().check("x : Union[int[], str[]]") \ + FileCheck().check("x : Union(int[], str[])") \ .run(s) def test_union_T_None_is_equivalent_to_optional_T(self):