Output UnionType str rep with () instead of [] (#69502)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69502

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D32902781

Pulled By: tugsbayasgalan

fbshipit-source-id: 67a73b209575437477cdbd3eb8f685019709e99c
This commit is contained in:
Tugsbayasgalan (Tugsuu) Manlaibaatar 2021-12-07 14:15:23 -08:00 committed by Facebook GitHub Bot
parent a8232ee1bc
commit 829b49b867
3 changed files with 32 additions and 12 deletions

View file

@ -147,7 +147,9 @@ struct TORCH_API UnionType : public Type {
protected:
explicit UnionType(std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
std::string unionStr(TypePrinter printer = nullptr, bool is_annotation_str = false) const;
std::string unionStr(
TypePrinter printer = nullptr,
bool is_annotation_str = false) const;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
bool has_free_variables_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)

View file

@ -1099,8 +1099,8 @@ bool UnionType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
});
}
std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) const {
std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str)
const {
std::stringstream ss;
bool can_hold_numbertype = this->canHoldType(*NumberType::get());
@ -1116,7 +1116,10 @@ std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) con
return false;
};
ss << "Union[";
std::string open_delimeter = is_annotation_str ? "[" : "(";
std::string close_delimeter = is_annotation_str ? "]" : ")";
ss << "Union" + open_delimeter;
bool printed = false;
for (size_t i = 0; i < types_.size(); ++i) {
if (!can_hold_numbertype || !is_numbertype(types_[i])) {
@ -1141,7 +1144,7 @@ std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) con
ss << NumberType::get()->str();
}
}
ss << "]";
ss << close_delimeter;
return ss.str();
}

View file

@ -33,6 +33,21 @@ class TestUnion(JitTestCase):
equivalent functions to emulate `checkScript`.
"""
def test_check_union_annotation(self):
def test_func(a: Union[int, float], b: Optional[int]):
return 0
scripted_func = torch.jit.script(test_func)
graph_rep = str(scripted_func.graph)
code_rep = str(scripted_func.code)
# TS graph IR for Union should be annotated as Union()
FileCheck().check("Union(").check("int?").run(graph_rep)
# Serialized code for Union should be annotated as Union[]
FileCheck().check("Union[").check("Optional[int]").run(code_rep)
self.checkScript(test_func, (5, 6))
# this shouldn't error out
torch._C.parse_ir(str(scripted_func.graph))
def test_union_with_scalar_values(self):
def fn(x: Union[int, float]) -> str:
return "foo"
@ -210,7 +225,7 @@ class TestUnion(JitTestCase):
s = fn.graph
FileCheck().check("x : Union[float, int, str]") \
FileCheck().check("x : Union(float, int, str)") \
.run(s)
def test_unions_of_a_single_argument_vanish(self):
@ -230,7 +245,7 @@ class TestUnion(JitTestCase):
s = fn.graph
FileCheck().check("x : Union[int, str]") \
FileCheck().check("x : Union(int, str)") \
.run(s)
def test_union_redundant_arguments_are_skipped_optional(self):
@ -240,7 +255,7 @@ class TestUnion(JitTestCase):
s = fn.graph
FileCheck().check("x : Union[float, int, NoneType]") \
FileCheck().check("x : Union(float, int, NoneType)") \
.run(s)
def test_union_redundant_arguments_are_skipped_subtyping(self):
@ -250,7 +265,7 @@ class TestUnion(JitTestCase):
s = fn.graph
FileCheck().check("x : Union[(int?, int), str]") \
FileCheck().check("x : Union((int?, int), str)") \
.run(s)
def test_union_redundant_arguments_are_skipped_container(self):
@ -260,7 +275,7 @@ class TestUnion(JitTestCase):
s = fn.graph
FileCheck().check("x : Union[float[], str[]]") \
FileCheck().check("x : Union(float[], str[])") \
.run(s)
def test_union_argument_order_is_ignored(self):
@ -273,7 +288,7 @@ class TestUnion(JitTestCase):
return "foo"
for s in (fn1.graph, fn2.graph):
FileCheck().check("x : Union[int, str]") \
FileCheck().check("x : Union(int, str)") \
.run(s)
def test_union_argument_order_is_ignored_container(self):
@ -286,7 +301,7 @@ class TestUnion(JitTestCase):
return "foo"
for s in (fn1.graph, fn2.graph):
FileCheck().check("x : Union[int[], str[]]") \
FileCheck().check("x : Union(int[], str[])") \
.run(s)
def test_union_T_None_is_equivalent_to_optional_T(self):