mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Output UnionType str rep with () instead of [] (#69502)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69502 Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D32902781 Pulled By: tugsbayasgalan fbshipit-source-id: 67a73b209575437477cdbd3eb8f685019709e99c
This commit is contained in:
parent
a8232ee1bc
commit
829b49b867
3 changed files with 32 additions and 12 deletions
|
|
@ -147,7 +147,9 @@ struct TORCH_API UnionType : public Type {
|
|||
protected:
|
||||
explicit UnionType(std::vector<TypePtr> types, TypeKind kind=TypeKind::UnionType);
|
||||
std::string annotation_str_impl(TypePrinter printer = nullptr) const override;
|
||||
std::string unionStr(TypePrinter printer = nullptr, bool is_annotation_str = false) const;
|
||||
std::string unionStr(
|
||||
TypePrinter printer = nullptr,
|
||||
bool is_annotation_str = false) const;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
bool has_free_variables_;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
|
||||
|
|
|
|||
|
|
@ -1099,8 +1099,8 @@ bool UnionType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
|
|||
});
|
||||
}
|
||||
|
||||
|
||||
std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) const {
|
||||
std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str)
|
||||
const {
|
||||
std::stringstream ss;
|
||||
|
||||
bool can_hold_numbertype = this->canHoldType(*NumberType::get());
|
||||
|
|
@ -1116,7 +1116,10 @@ std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) con
|
|||
return false;
|
||||
};
|
||||
|
||||
ss << "Union[";
|
||||
std::string open_delimeter = is_annotation_str ? "[" : "(";
|
||||
std::string close_delimeter = is_annotation_str ? "]" : ")";
|
||||
|
||||
ss << "Union" + open_delimeter;
|
||||
bool printed = false;
|
||||
for (size_t i = 0; i < types_.size(); ++i) {
|
||||
if (!can_hold_numbertype || !is_numbertype(types_[i])) {
|
||||
|
|
@ -1141,7 +1144,7 @@ std::string UnionType::unionStr(TypePrinter printer, bool is_annotation_str) con
|
|||
ss << NumberType::get()->str();
|
||||
}
|
||||
}
|
||||
ss << "]";
|
||||
ss << close_delimeter;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,21 @@ class TestUnion(JitTestCase):
|
|||
equivalent functions to emulate `checkScript`.
|
||||
"""
|
||||
|
||||
def test_check_union_annotation(self):
|
||||
def test_func(a: Union[int, float], b: Optional[int]):
|
||||
return 0
|
||||
|
||||
scripted_func = torch.jit.script(test_func)
|
||||
graph_rep = str(scripted_func.graph)
|
||||
code_rep = str(scripted_func.code)
|
||||
# TS graph IR for Union should be annotated as Union()
|
||||
FileCheck().check("Union(").check("int?").run(graph_rep)
|
||||
# Serialized code for Union should be annotated as Union[]
|
||||
FileCheck().check("Union[").check("Optional[int]").run(code_rep)
|
||||
self.checkScript(test_func, (5, 6))
|
||||
# this shouldn't error out
|
||||
torch._C.parse_ir(str(scripted_func.graph))
|
||||
|
||||
def test_union_with_scalar_values(self):
|
||||
def fn(x: Union[int, float]) -> str:
|
||||
return "foo"
|
||||
|
|
@ -210,7 +225,7 @@ class TestUnion(JitTestCase):
|
|||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union[float, int, str]") \
|
||||
FileCheck().check("x : Union(float, int, str)") \
|
||||
.run(s)
|
||||
|
||||
def test_unions_of_a_single_argument_vanish(self):
|
||||
|
|
@ -230,7 +245,7 @@ class TestUnion(JitTestCase):
|
|||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union[int, str]") \
|
||||
FileCheck().check("x : Union(int, str)") \
|
||||
.run(s)
|
||||
|
||||
def test_union_redundant_arguments_are_skipped_optional(self):
|
||||
|
|
@ -240,7 +255,7 @@ class TestUnion(JitTestCase):
|
|||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union[float, int, NoneType]") \
|
||||
FileCheck().check("x : Union(float, int, NoneType)") \
|
||||
.run(s)
|
||||
|
||||
def test_union_redundant_arguments_are_skipped_subtyping(self):
|
||||
|
|
@ -250,7 +265,7 @@ class TestUnion(JitTestCase):
|
|||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union[(int?, int), str]") \
|
||||
FileCheck().check("x : Union((int?, int), str)") \
|
||||
.run(s)
|
||||
|
||||
def test_union_redundant_arguments_are_skipped_container(self):
|
||||
|
|
@ -260,7 +275,7 @@ class TestUnion(JitTestCase):
|
|||
|
||||
s = fn.graph
|
||||
|
||||
FileCheck().check("x : Union[float[], str[]]") \
|
||||
FileCheck().check("x : Union(float[], str[])") \
|
||||
.run(s)
|
||||
|
||||
def test_union_argument_order_is_ignored(self):
|
||||
|
|
@ -273,7 +288,7 @@ class TestUnion(JitTestCase):
|
|||
return "foo"
|
||||
|
||||
for s in (fn1.graph, fn2.graph):
|
||||
FileCheck().check("x : Union[int, str]") \
|
||||
FileCheck().check("x : Union(int, str)") \
|
||||
.run(s)
|
||||
|
||||
def test_union_argument_order_is_ignored_container(self):
|
||||
|
|
@ -286,7 +301,7 @@ class TestUnion(JitTestCase):
|
|||
return "foo"
|
||||
|
||||
for s in (fn1.graph, fn2.graph):
|
||||
FileCheck().check("x : Union[int[], str[]]") \
|
||||
FileCheck().check("x : Union(int[], str[])") \
|
||||
.run(s)
|
||||
|
||||
def test_union_T_None_is_equivalent_to_optional_T(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue