mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Allow for heterogenous List and Dict values + Improve container typing algorithm (#57137)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57137 This PR corrects and expands our typing algorithm for unannotated, non-empty dicts and lists. Previously, to verify type correctness for an unannotated, non-empty container, we had gotten the type of the first element in the container, then checked if each following element was a subtype of the first type. That's too restrictive--what if the first element were a subtype of the second element? Instead, we should type the container by getting the smallest common supertype of all the given elements. We need slightly different rules for keys and values in dicts, though: because the set of key types is restricted, finding two key types that cannot be unified should cause an error. On the other hand, the set of value types is not restricted, so we should be able to use `Any` as a valid supertype. We need to keep the set of keys restricted since the keys are used to generate and match schemas. This does not break backwards compatibility, because the default element type is the smallest supertype of all the given types. So, if someone creates an unannotated dict where the keys are all `str` and the values are all `torch.Tensor`, the dict will be inferred to `Dict[str, Tensor]` just like it was before. Empty lists are still typed as `List[torch.Tensor],` and empty dicts are still typed as `Dict[str, Tensor]`. This PR unblocks three engineers on an FB-internal team and improves FX-TorchScript compatibility. Test Plan: Imported from OSS Reviewed By: gmagogsfm Differential Revision: D28231839 Pulled By: ansley fbshipit-source-id: 7297bf239749daa54895add708185c75e6ca5999
This commit is contained in:
parent
ccd0977060
commit
f5c10fdbd3
8 changed files with 393 additions and 90 deletions
|
|
@ -1574,15 +1574,21 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) {
|
|||
// then t2 will be returned (and vice versa).
|
||||
// Two different tensortypes will return dynamic.
|
||||
// Currently we chose not to support returning a NumberType for a float & int
|
||||
// input because of a lack of operator support for NumberType
|
||||
// input because of a lack of operator support for NumberType.
|
||||
// If `type_hint` is an `InterfaceType`, then we can use that as a
|
||||
// potential supertype for `ClassType`s in the list. Otherwise, we have
|
||||
// no way to find and use some common interface type
|
||||
TORCH_API c10::optional<TypePtr> unifyTypes(
|
||||
const TypePtr& t1,
|
||||
const TypePtr& t2,
|
||||
bool default_to_any = false);
|
||||
bool default_to_any = false,
|
||||
TypePtr type_hint=nullptr);
|
||||
|
||||
TORCH_API c10::optional<TypePtr> unifyTypeList(
|
||||
at::ArrayRef<TypePtr> elements,
|
||||
std::ostream& why_not);
|
||||
std::ostream& why_not,
|
||||
bool default_to_any=false,
|
||||
TypePtr type_hint=nullptr);
|
||||
|
||||
namespace detail {
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -265,7 +265,7 @@ AnyEnumTypePtr AnyEnumType::get() {
|
|||
return value;
|
||||
}
|
||||
|
||||
c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2) {
|
||||
c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_any=false, TypePtr type_hint=nullptr) {
|
||||
// check direct subtyping relation
|
||||
if (t1->isSubtypeOf(t2)) {
|
||||
return t2;
|
||||
|
|
@ -308,7 +308,7 @@ c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2) {
|
|||
}
|
||||
std::vector<TypePtr> elements;
|
||||
for (size_t i = 0; i < tuple1->elements().size(); i++) {
|
||||
if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i))) {
|
||||
if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i), default_to_any)) {
|
||||
elements.push_back(*elem);
|
||||
} else {
|
||||
return c10::nullopt;
|
||||
|
|
@ -337,11 +337,18 @@ c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2) {
|
|||
return t1_unshaped;
|
||||
}
|
||||
|
||||
// Check whether or not `type_hint` is a common parent. This case
|
||||
// could occur if we had two class types that had been annotated with
|
||||
// a common interface
|
||||
if (type_hint && t1->isSubtypeOf(type_hint) && t2->isSubtypeOf(type_hint)) {
|
||||
return type_hint;
|
||||
}
|
||||
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_any) {
|
||||
auto unified = unifyTypesImpl(t1, t2);
|
||||
c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_any, TypePtr type_hint) {
|
||||
auto unified = unifyTypesImpl(t1, t2, default_to_any, type_hint);
|
||||
|
||||
if (default_to_any && !unified) {
|
||||
return AnyType::get();
|
||||
|
|
@ -352,7 +359,9 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool def
|
|||
|
||||
c10::optional<TypePtr> unifyTypeList(
|
||||
at::ArrayRef<TypePtr> elements,
|
||||
std::ostream& why_not) {
|
||||
std::ostream& why_not,
|
||||
bool default_to_any,
|
||||
TypePtr type_hint) {
|
||||
if (elements.size() == 0) {
|
||||
why_not << "Cannot get unified type from empty list";
|
||||
return c10::nullopt;
|
||||
|
|
@ -360,7 +369,7 @@ c10::optional<TypePtr> unifyTypeList(
|
|||
|
||||
TypePtr ret_type = elements.at(0);
|
||||
for (size_t i = 1; i < elements.size() && ret_type; ++i) {
|
||||
auto maybe_unified = unifyTypes(ret_type, elements.at(i));
|
||||
c10::optional<TypePtr> maybe_unified = unifyTypes(ret_type, elements.at(i), default_to_any, type_hint);
|
||||
if (!maybe_unified) {
|
||||
why_not << "Could not unify type list since element " << i << " of type "
|
||||
<< elements.at(i)->repr_str()
|
||||
|
|
@ -368,7 +377,7 @@ c10::optional<TypePtr> unifyTypeList(
|
|||
<< ret_type->repr_str() << ")";
|
||||
return c10::nullopt;
|
||||
}
|
||||
ret_type = maybe_unified.value();
|
||||
ret_type = *maybe_unified;
|
||||
}
|
||||
|
||||
return ret_type;
|
||||
|
|
|
|||
|
|
@ -79,6 +79,9 @@ annotations from the example above one would write:
|
|||
|
||||
* `CHECK: <pattern>`
|
||||
Scans the input until `PATTERN` is found. Fails if the pattern is not found.
|
||||
* `CHECK-NEXT: <pattern>`
|
||||
Scans the input on the line immediately following the previous CHECK until
|
||||
`PATTERN` is found. Fails if the pattern is not found on that line.
|
||||
* `CHECK-NOT: <pattern>`
|
||||
Scans the input and fails if `PATTERN` is found on any line. The scan stops when
|
||||
a match for a next `CHECK` is found.
|
||||
|
|
|
|||
|
|
@ -244,10 +244,10 @@ class TestList(JitTestCase):
|
|||
self.checkScript(fn, ())
|
||||
|
||||
def test_dict_keyword_with_mismatched_annotations(self):
|
||||
# TODO: This fails during function schema matching, so the error
|
||||
# message is not very informative to the user. Change logic so
|
||||
# that the error is thrown at a different time?
|
||||
err_msg = "Arguments for call are not valid"
|
||||
err_msg = r"Dict type annotation `Dict\[int, str\]` did not "\
|
||||
"match the types of the actual dict items"
|
||||
err_msg = r"Dict type annotation `Dict\[int, str\]` did not "\
|
||||
"match the type of an actual key type `str`"
|
||||
highlight_msg = "dict([(\"foo\", 1), (\"bar\", 2), (\"baz\", 3"
|
||||
with self.assertRaisesRegexWithHighlight(RuntimeError, err_msg, highlight_msg):
|
||||
@torch.jit.script
|
||||
|
|
|
|||
|
|
@ -140,7 +140,9 @@ class TestTypesAndAnnotation(JitTestCase):
|
|||
wrong : List[int] = [0.5]
|
||||
return wrong
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Lists must contain only a single type"):
|
||||
with self.assertRaisesRegex(RuntimeError, "List type annotation"
|
||||
r" `List\[int\]` did not match the "
|
||||
"types of the given list elements"):
|
||||
torch.jit.script(wrong_type)
|
||||
|
||||
def test_optional_no_element_type_annotation(self):
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import os
|
|||
import sys
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS
|
||||
from collections import namedtuple
|
||||
|
|
@ -73,11 +74,119 @@ class TestTyping(JitTestCase):
|
|||
self.checkScript(test_dict_tensor_key, (dict_a, inp1))
|
||||
self.checkScript(test_dict_tensor_key, (dict_a, inp2))
|
||||
|
||||
def test_dict_types(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "single type"):
|
||||
@torch.jit.script
|
||||
def foo():
|
||||
new_item = {'score': [1.0], 'ys': [1, 2, 3]}
|
||||
def test_list_type_refinement_defaults_to_Any_list_creation(self):
|
||||
def fn(x):
|
||||
tup1 = ("foo", torch.tensor(2))
|
||||
tup2 = ("bar", {"23": torch.tensor(3)})
|
||||
tup3 = ("baz", x)
|
||||
l = list((tup1, tup2)) # noqa: C410
|
||||
l.append(tup3)
|
||||
tup4 = l[0]
|
||||
if torch.jit.isinstance(tup4, Tuple[str, torch.Tensor]):
|
||||
t = tup4[1]
|
||||
if isinstance(t, torch.Tensor):
|
||||
l[0] = (tup4[0], torch.add(t, t))
|
||||
return l
|
||||
|
||||
self.checkScript(fn, (torch.arange(5),))
|
||||
|
||||
graph = torch.jit.script(fn).graph
|
||||
|
||||
print(graph)
|
||||
|
||||
# Check that we're making a `List[Tuple[str, Any]]`
|
||||
FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph)
|
||||
|
||||
def test_list_type_refinement_defaults_to_Any_list_comprehension(self):
|
||||
def fn(x):
|
||||
tup1 = ("foo", torch.tensor(2))
|
||||
tup2 = ("bar", {"23": torch.tensor(3)})
|
||||
tup3 = ("baz", x)
|
||||
l_ = [tup1, tup2]
|
||||
l = [t for t in l_] # noqa: C416
|
||||
l.append(tup3)
|
||||
tup4 = l[0]
|
||||
if torch.jit.isinstance(tup4, Tuple[str, torch.Tensor]):
|
||||
t = tup4[1]
|
||||
if isinstance(t, torch.Tensor):
|
||||
l[0] = (tup4[0], torch.add(t, t))
|
||||
return l
|
||||
|
||||
self.checkScript(fn, (torch.arange(5),))
|
||||
|
||||
graph = torch.jit.script(fn).graph
|
||||
|
||||
print(graph)
|
||||
|
||||
# Check that we're making a `List[Tuple[str, Any]]`
|
||||
FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph)
|
||||
|
||||
def test_list_type_refinement_annotation_element_mismatch(self):
|
||||
def fn():
|
||||
l: List[int] = [1, 2, "foo", 3]
|
||||
return l
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "List type annotation"
|
||||
r" `List\[int\]` did not match the "
|
||||
"types of the given list elements"):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_dict_type_refinement_defaults_to_Any_dict_creation(self):
|
||||
def fn(x):
|
||||
d = dict(foo=torch.tensor(2),
|
||||
bar={"23": torch.tensor(3)})
|
||||
d["baz"] = x
|
||||
t = d["foo"]
|
||||
if isinstance(t, torch.Tensor):
|
||||
d["bar"] = torch.add(t, t)
|
||||
return d
|
||||
|
||||
self.checkScript(fn, (torch.arange(5),))
|
||||
|
||||
graph = torch.jit.script(fn).graph
|
||||
|
||||
FileCheck().check(r"Dict(str, Any) = prim::DictConstruct").run(graph)
|
||||
|
||||
def test_dict_type_refinement_defaults_to_Any_dict_comprehension(self):
|
||||
def fn(x):
|
||||
d = {"foo": torch.tensor(2),
|
||||
"bar": {"23": torch.tensor(3)}}
|
||||
d["baz"] = x
|
||||
t = d["foo"]
|
||||
if isinstance(t, torch.Tensor):
|
||||
d["bar"] = torch.add(t, t)
|
||||
return d
|
||||
|
||||
self.checkScript(fn, (torch.arange(5),))
|
||||
|
||||
graph = torch.jit.script(fn).graph
|
||||
|
||||
FileCheck().check("Dict(str, Any) = prim::DictConstruct").run(graph)
|
||||
|
||||
def test_dict_type_refinement_annotation_key_mismatch(self):
|
||||
def fn():
|
||||
l1 = [1, 2, "foo", 3]
|
||||
l2 = ["foo", "bar", "baz", "qux"]
|
||||
d: Dict[int, str] = {k : v for k, v in zip(l1, l2)}
|
||||
return l
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
|
||||
r" `Dict\[int, str\]` did not match"
|
||||
" the type of an actual key type"):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_dict_type_refinement_annotation_value_mismatch(self):
|
||||
def fn():
|
||||
l1 = ["foo", "bar", "baz", "qux"]
|
||||
l2 = [1, 2, "foo", 3]
|
||||
d: Dict[str, int] = {k : v for k, v in zip(l1, l2)}
|
||||
return l
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
|
||||
r" `Dict\[str, int\]` did not match"
|
||||
" the type of an actual value "
|
||||
"type"):
|
||||
torch.jit.script(fn)
|
||||
|
||||
def test_dict_invalid_annotations(self):
|
||||
# Check for invalid value type annotation
|
||||
|
|
@ -200,16 +309,6 @@ class TestTyping(JitTestCase):
|
|||
self.checkScript(fn, [])
|
||||
self.checkScript(fn2, (torch.ones(2, 2),))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Could not unify"):
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
return [1, 1.2]
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Could not unify"):
|
||||
@torch.jit.script
|
||||
def fn():
|
||||
return [1, torch.ones(1, 2)]
|
||||
|
||||
# to avoid defining sum_list in multiple tests
|
||||
def get_sum_list_fn(self):
|
||||
def sum_list(a):
|
||||
|
|
|
|||
|
|
@ -10584,7 +10584,10 @@ dedent """
|
|||
def f5(a):
|
||||
torch.cat([3])
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, 'Lists must contain only a single type'):
|
||||
with self.assertRaisesRegex(RuntimeError, r'Expected a value of'
|
||||
r' type \'List\[int\]\' for argument'
|
||||
r' \'size\' but instead found type '
|
||||
r'\'List\[Any\]\''):
|
||||
@torch.jit.script
|
||||
def f6(a):
|
||||
a.expand(size=[3, [4]])
|
||||
|
|
|
|||
|
|
@ -1321,13 +1321,51 @@ struct to_ir {
|
|||
pushFrame(comprehension_block);
|
||||
WithInsertPoint guard(comprehension_block);
|
||||
auto emit_body = [&]() {
|
||||
auto comprehension_out = emitExpr(lc.elt());
|
||||
Value* out = emitExpr(lc.elt());
|
||||
|
||||
// If we didn't have a type annotation, the type of the list would
|
||||
// be set to `Tensor`. We don't want to unify this default type
|
||||
// with the actual elements in the list, so let the type begin as
|
||||
// the first element in the list
|
||||
if (!type_set) {
|
||||
list_value->setType(ListType::create(comprehension_out->type()));
|
||||
list_value->setType(ListType::create(out->type()));
|
||||
type_set = true;
|
||||
}
|
||||
|
||||
ListTypePtr lt = list_value->type()->expect<ListType>();
|
||||
|
||||
const TypePtr element_type_hint =
|
||||
type_hint ? type_hint->expect<ListType>()->getElementType() : nullptr;
|
||||
|
||||
auto unified = unifyTypes(
|
||||
lt->getElementType(),
|
||||
out->type(),
|
||||
/*default_to_any=*/true,
|
||||
element_type_hint);
|
||||
|
||||
if (lt->getElementType() != AnyType::get() &&
|
||||
*unified == AnyType::get()) {
|
||||
TORCH_WARN(
|
||||
"List consists of heterogeneous types, which means",
|
||||
" that it has been typed as `List[Any]`. To use "
|
||||
"any of the values in the List, it will be "
|
||||
"necessary to add an `assert isinstance` statement "
|
||||
"before first use to trigger type refinement. The first ",
|
||||
"non-matching element was typed as ",
|
||||
out->type()->repr_str(),
|
||||
", while the elements before it "
|
||||
"were ",
|
||||
lt->getElementType()->repr_str(),
|
||||
"\n",
|
||||
lc.range().str());
|
||||
}
|
||||
|
||||
if (!type_hint) {
|
||||
list_value->setType(ListType::create(*unified));
|
||||
}
|
||||
|
||||
NamedValue self = NamedValue(loc, "self", list_value);
|
||||
NamedValue input = NamedValue(loc, "", comprehension_out);
|
||||
NamedValue input = NamedValue(loc, "", out);
|
||||
emitBuiltinCall(loc, *graph, aten::append, {input}, {}, self);
|
||||
};
|
||||
emitFor(targets_list, itrs, loc, emit_body);
|
||||
|
|
@ -1366,10 +1404,88 @@ struct to_ir {
|
|||
auto emit_body = [&]() {
|
||||
auto k = emitExpr(dc.key());
|
||||
auto v = emitExpr(dc.value());
|
||||
|
||||
// Make sure that any key and value types are subtypes of the
|
||||
// annotatated key/value types
|
||||
if (type_hint) {
|
||||
DictTypePtr dict_type_hint = type_hint->expect<DictType>();
|
||||
|
||||
std::stringstream ss;
|
||||
std::stringstream err;
|
||||
|
||||
bool is_key_subtype =
|
||||
k->type()->isSubtypeOfExt(dict_type_hint->getKeyType(), &ss);
|
||||
|
||||
if (!is_key_subtype) {
|
||||
err << "Dict type annotation `" << dict_type_hint->repr_str()
|
||||
<< "` did not match the "
|
||||
<< "type of an actual key type `" << k->type()->repr_str()
|
||||
<< "`\n"
|
||||
<< ss.str();
|
||||
}
|
||||
|
||||
ss.str(std::string());
|
||||
bool is_value_subtype =
|
||||
v->type()->isSubtypeOfExt(dict_type_hint->getValueType(), &ss);
|
||||
|
||||
if (!is_value_subtype) {
|
||||
err << "Dict type annotation `" << dict_type_hint->repr_str()
|
||||
<< "` did not match the "
|
||||
<< "type of an actual value type `" << v->type()->repr_str()
|
||||
<< "`\n"
|
||||
<< ss.str();
|
||||
}
|
||||
|
||||
if (!is_key_subtype || !is_value_subtype) {
|
||||
throw ErrorReport(dc) << err.str();
|
||||
}
|
||||
}
|
||||
|
||||
// If we didn't have a type annotation, the type of the dict would
|
||||
// be set to `(str, Tensor)`. We don't want to unify this default
|
||||
// type with the actual elements in the dict, so let the type
|
||||
// begin as the first element in the dict
|
||||
if (!type_set) {
|
||||
dict_value->setType(DictType::create(k->type(), v->type()));
|
||||
type_set = true;
|
||||
}
|
||||
|
||||
DictTypePtr dt = dict_value->type()->expect<DictType>();
|
||||
|
||||
const TypePtr value_type_hint =
|
||||
type_hint ? type_hint->expect<DictType>()->getKeyType() : nullptr;
|
||||
|
||||
c10::optional<TypePtr> unified = unifyTypes(
|
||||
dt->getValueType(),
|
||||
v->type(),
|
||||
/*default_to_any=*/true,
|
||||
value_type_hint);
|
||||
|
||||
// Warn the user if we inferred the type of the values to be `Any`
|
||||
// even though the annotation was something else
|
||||
if (dt->getValueType() != AnyType::get() && *unified == AnyType::get()) {
|
||||
TORCH_WARN(
|
||||
"Dict consists of heterogeneous types, which means",
|
||||
" that it has been typed as `Dict[str, Any]`. To use "
|
||||
"any of the values in the Dict, it will be "
|
||||
"necessary to add an `assert isinstance` statement "
|
||||
"before first use to trigger type refinement. The first ",
|
||||
"non-matching element was typed as ",
|
||||
v->type()->repr_str(),
|
||||
", while the elements before it "
|
||||
"were ",
|
||||
dt->getValueType()->repr_str(),
|
||||
"\n",
|
||||
dc.range().str());
|
||||
}
|
||||
|
||||
// We only want to set `dict_value` if we don't have a type hint
|
||||
// to allow for the case that `*unified` is a subtype of
|
||||
// the value type given by `type_hint`
|
||||
if (!type_hint) {
|
||||
dict_value->setType(DictType::create(k->type(), *unified));
|
||||
}
|
||||
|
||||
NamedValue self = NamedValue(loc, "self", dict_value);
|
||||
NamedValue input_k = NamedValue(loc, "", k);
|
||||
NamedValue input_v = NamedValue(loc, "", v);
|
||||
|
|
@ -3534,6 +3650,66 @@ struct to_ir {
|
|||
->call(tree->range(), method, named_values, {}, 0));
|
||||
}
|
||||
|
||||
Value* emitListLiteral(ListLiteral ll, TypePtr type_hint) {
|
||||
auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);
|
||||
|
||||
// Determine the element type of the list. If we have a type hint
|
||||
// of `List[T]`, use `T`. If the list is non-empty, find the
|
||||
// greatest common supertype of all the list elements (defaulting to
|
||||
// `Any` as a catch-all supertype). Assume `[]` is `List[Tensor]`
|
||||
TypePtr elem_type = TensorType::get();
|
||||
|
||||
if (type_hint) {
|
||||
if (type_hint->kind() == TypeKind::ListType) {
|
||||
elem_type = type_hint->expectRef<ListType>().getElementType();
|
||||
} else {
|
||||
// If the type hint was not `List[T]`, throw an error
|
||||
throw ErrorReport(ll) << "Expected a List type hint but instead got "
|
||||
<< type_hint->repr_str();
|
||||
}
|
||||
}
|
||||
|
||||
if (!values.empty()) {
|
||||
auto types = fmap(values, [](const Value* v) { return v->type(); });
|
||||
|
||||
std::stringstream nowhere; // never used
|
||||
|
||||
const TypePtr element_type_hint =
|
||||
type_hint ? type_hint->expect<ListType>()->getElementType() : nullptr;
|
||||
|
||||
c10::optional<TypePtr> unified = unifyTypeList(
|
||||
types, nowhere, /*default_to_any=*/true, element_type_hint);
|
||||
|
||||
if (!type_hint && *unified == AnyType::get()) {
|
||||
TORCH_WARN(
|
||||
"List consists of heterogeneous types, which means",
|
||||
" that it has been typed as `List[Any]`. To use "
|
||||
"any of the values in the List, it will be "
|
||||
"necessary to add an `assert isinstance` statement "
|
||||
"before first use to trigger type refinement. \n",
|
||||
ll.range().str());
|
||||
}
|
||||
|
||||
if (type_hint && !(*unified)->isSubtypeOf(elem_type)) {
|
||||
throw ErrorReport(ll)
|
||||
<< "List type annotation `" << type_hint->repr_str()
|
||||
<< "` did not match the types of the given list elements,"
|
||||
<< " which were unified to " << (*unified)->repr_str();
|
||||
}
|
||||
|
||||
// We only want to set `elem_type` if we don't have a type hint
|
||||
// to allow for the case that `*unified` is a subtype of
|
||||
// `type_hint`
|
||||
if (!type_hint) {
|
||||
elem_type = *unified;
|
||||
}
|
||||
}
|
||||
|
||||
Value* result =
|
||||
graph->insertNode(graph->createList(elem_type, values))->output();
|
||||
return result;
|
||||
}
|
||||
|
||||
Value* emitSimpleExpr(
|
||||
const TreeRef& tree,
|
||||
const TypePtr& type_hint = nullptr) {
|
||||
|
|
@ -3616,46 +3792,7 @@ struct to_ir {
|
|||
} break;
|
||||
case TK_LIST_LITERAL: {
|
||||
auto ll = ListLiteral(tree);
|
||||
auto values = getValues(ll.inputs(), /*maybe_unpack=*/true);
|
||||
|
||||
// determine the element type of the list
|
||||
// if we have a type hint of List[T], use T
|
||||
// if the list is non-empty use type_of(list[0])
|
||||
// otherwise assume it is List[Tensor]
|
||||
TypePtr elem_type = TensorType::get();
|
||||
if (type_hint) {
|
||||
if (type_hint->kind() == TypeKind::ListType) {
|
||||
elem_type = type_hint->expectRef<ListType>().getElementType();
|
||||
} else {
|
||||
// If the type hint was not a List[T] throw an error
|
||||
throw ErrorReport(tree)
|
||||
<< "Expected a List type hint but instead got "
|
||||
<< type_hint->repr_str();
|
||||
}
|
||||
} else if (!values.empty()) {
|
||||
std::stringstream ss;
|
||||
auto types = fmap(values, [](const Value* v) { return v->type(); });
|
||||
auto maybe_elem_type = unifyTypeList(types, ss);
|
||||
if (!maybe_elem_type) {
|
||||
throw ErrorReport(tree) << "Lists must contain only a single type\n"
|
||||
<< ss.str();
|
||||
}
|
||||
elem_type = maybe_elem_type.value();
|
||||
}
|
||||
|
||||
for (auto v : values) {
|
||||
std::stringstream ss;
|
||||
if (!v->type()->isSubtypeOfExt(elem_type, &ss)) {
|
||||
throw ErrorReport(tree)
|
||||
<< "Lists must contain only a single type, expected: "
|
||||
<< elem_type->repr_str() << " but found "
|
||||
<< v->type()->repr_str() << " instead.\n"
|
||||
<< ss.str();
|
||||
}
|
||||
}
|
||||
Value* result =
|
||||
graph->insertNode(graph->createList(elem_type, values))->output();
|
||||
return result;
|
||||
return emitListLiteral(ll, type_hint);
|
||||
} break;
|
||||
case TK_TUPLE_LITERAL: {
|
||||
auto ll = TupleLiteral(tree);
|
||||
|
|
@ -3690,24 +3827,68 @@ struct to_ir {
|
|||
}
|
||||
AT_ASSERT(key_type != nullptr && value_type != nullptr);
|
||||
|
||||
auto checkTypeOfValues = [](const TypePtr& type,
|
||||
const char* what,
|
||||
const std::vector<Value*>& values,
|
||||
TreeList trees) {
|
||||
for (size_t i = 0, N = values.size(); i < N; ++i) {
|
||||
std::stringstream ss;
|
||||
if (!values[i]->type()->isSubtypeOfExt(type, &ss)) {
|
||||
throw ErrorReport(trees[i])
|
||||
<< "Dict " << what
|
||||
<< " must contain only a single type, expected: "
|
||||
<< type->repr_str() << " but found "
|
||||
<< values[i]->type()->repr_str() << " instead.\n"
|
||||
<< ss.str();
|
||||
for (size_t i = 0; i < keys.size(); ++i) {
|
||||
std::stringstream ss;
|
||||
if (!keys[i]->type()->isSubtypeOfExt(key_type, &ss)) {
|
||||
throw ErrorReport(key_trees[i])
|
||||
<< "Dict keys must contain "
|
||||
<< "only a single type. Expected: " << key_type->repr_str()
|
||||
<< " but found " << keys[i]->type()->repr_str() << " instead.\n"
|
||||
<< ss.str();
|
||||
}
|
||||
}
|
||||
|
||||
if (!values.empty()) {
|
||||
auto types = fmap(values, [](const Value* v) { return v->type(); });
|
||||
|
||||
std::stringstream nowhere; // never used
|
||||
|
||||
const TypePtr value_type_hint =
|
||||
type_hint ? type_hint->expect<DictType>()->getKeyType() : nullptr;
|
||||
|
||||
c10::optional<TypePtr> unified = unifyTypeList(
|
||||
types,
|
||||
/*why_not=*/nowhere,
|
||||
/*default_to_any=*/true,
|
||||
value_type_hint);
|
||||
|
||||
if (!type_hint && *unified == AnyType::get()) {
|
||||
TORCH_WARN(
|
||||
"Dict values consist of heterogeneous types, which "
|
||||
"means that they have been typed as `Any`. To use "
|
||||
"any of the values in the Dist, it will be "
|
||||
"necessary to add an `assert isinstance` statement "
|
||||
"before first use to trigger type refinement. \n",
|
||||
dl.range().str());
|
||||
}
|
||||
|
||||
if (type_hint) {
|
||||
TypePtr value_type_hint =
|
||||
type_hint->expect<DictType>()->getValueType();
|
||||
for (size_t i = 0; i < types.size(); ++i) {
|
||||
TORCH_CHECK(
|
||||
types[i]->isSubtypeOf(value_type_hint),
|
||||
"Type "
|
||||
"hint for dict was",
|
||||
type_hint->repr_str(),
|
||||
"but the value ",
|
||||
"at index ",
|
||||
i,
|
||||
" has type ",
|
||||
types[i]->repr_str(),
|
||||
", which is not a valid"
|
||||
" subtype of ",
|
||||
value_type_hint->repr_str());
|
||||
}
|
||||
}
|
||||
};
|
||||
checkTypeOfValues(key_type, "keys", keys, key_trees);
|
||||
checkTypeOfValues(value_type, "values", values, value_trees);
|
||||
|
||||
// We only want to set `value_type` if we don't have a type
|
||||
// hint to allow for the case that `*unified` is a subtype of
|
||||
// the value type given by `type_hint`
|
||||
if (!type_hint) {
|
||||
value_type = *unified;
|
||||
}
|
||||
}
|
||||
|
||||
return graph
|
||||
->insertNode(graph->createDict(key_type, value_type, keys, values))
|
||||
|
|
|
|||
Loading…
Reference in a new issue