diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 9fafdfd896f..d733fbd2da5 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1574,15 +1574,21 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) { // then t2 will be returned (and vice versa). // Two different tensortypes will return dynamic. // Currently we chose not to support returning a NumberType for a float & int -// input because of a lack of operator support for NumberType +// input because of a lack of operator support for NumberType. +// If `type_hint` is an `InterfaceType`, then we can use that as a +// potential supertype for `ClassType`s in the list. Otherwise, we have +// no way to find and use some common interface type TORCH_API c10::optional unifyTypes( const TypePtr& t1, const TypePtr& t2, - bool default_to_any = false); + bool default_to_any = false, + TypePtr type_hint=nullptr); TORCH_API c10::optional unifyTypeList( at::ArrayRef elements, - std::ostream& why_not); + std::ostream& why_not, + bool default_to_any=false, + TypePtr type_hint=nullptr); namespace detail { template diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index fa467b7a005..e8cffebcfaf 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -265,7 +265,7 @@ AnyEnumTypePtr AnyEnumType::get() { return value; } -c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2) { +c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_any=false, TypePtr type_hint=nullptr) { // check direct subtyping relation if (t1->isSubtypeOf(t2)) { return t2; @@ -308,7 +308,7 @@ c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2) { } std::vector elements; for (size_t i = 0; i < tuple1->elements().size(); i++) { - if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i))) { + if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i), 
default_to_any)) { elements.push_back(*elem); } else { return c10::nullopt; @@ -337,11 +337,18 @@ c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2) { return t1_unshaped; } + // Check whether or not `type_hint` is a common parent. This case + // could occur if we had two class types that had been annotated with + // a common interface + if (type_hint && t1->isSubtypeOf(type_hint) && t2->isSubtypeOf(type_hint)) { + return type_hint; + } + return c10::nullopt; } -c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_any) { - auto unified = unifyTypesImpl(t1, t2); +c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_any, TypePtr type_hint) { + auto unified = unifyTypesImpl(t1, t2, default_to_any, type_hint); if (default_to_any && !unified) { return AnyType::get(); @@ -352,7 +359,9 @@ c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2, bool def c10::optional unifyTypeList( at::ArrayRef elements, - std::ostream& why_not) { + std::ostream& why_not, + bool default_to_any, + TypePtr type_hint) { if (elements.size() == 0) { why_not << "Cannot get unified type from empty list"; return c10::nullopt; @@ -360,7 +369,7 @@ c10::optional unifyTypeList( TypePtr ret_type = elements.at(0); for (size_t i = 1; i < elements.size() && ret_type; ++i) { - auto maybe_unified = unifyTypes(ret_type, elements.at(i)); + c10::optional maybe_unified = unifyTypes(ret_type, elements.at(i), default_to_any, type_hint); if (!maybe_unified) { why_not << "Could not unify type list since element " << i << " of type " << elements.at(i)->repr_str() @@ -368,7 +377,7 @@ c10::optional unifyTypeList( << ret_type->repr_str() << ")"; return c10::nullopt; } - ret_type = maybe_unified.value(); + ret_type = *maybe_unified; } return ret_type; diff --git a/test/HowToWriteTestsUsingFileCheck.md b/test/HowToWriteTestsUsingFileCheck.md index 429d7a06b48..0795c23002a 100644 --- a/test/HowToWriteTestsUsingFileCheck.md +++ 
b/test/HowToWriteTestsUsingFileCheck.md @@ -79,6 +79,9 @@ annotations from the example above one would write: * `CHECK: <PATTERN>` Scans the input until `PATTERN` is found. Fails if the pattern is not found. +* `CHECK-NEXT: <PATTERN>` + Checks that `PATTERN` is found on the line immediately following the + previous `CHECK` match. Fails if the pattern is not found on that line. * `CHECK-NOT: <PATTERN>` Scans the input and fails if `PATTERN` is found on any line. The scan stops when a match for a next `CHECK` is found. diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index a7d30dae874..d8434515291 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -244,10 +244,10 @@ class TestList(JitTestCase): self.checkScript(fn, ()) def test_dict_keyword_with_mismatched_annotations(self): - # TODO: This fails during function schema matching, so the error - # message is not very informative to the user. Change logic so - # that the error is thrown at a different time? - err_msg = "Arguments for call are not valid" + err_msg = r"Dict type annotation `Dict\[int, str\]` did not "\ + "match the types of the actual dict items" + err_msg = r"Dict type annotation `Dict\[int, str\]` did not "\ + "match the type of an actual key type `str`" highlight_msg = "dict([(\"foo\", 1), (\"bar\", 2), (\"baz\", 3" with self.assertRaisesRegexWithHighlight(RuntimeError, err_msg, highlight_msg): @torch.jit.script diff --git a/test/jit/test_types.py b/test/jit/test_types.py index e7edc4734b4..5da4efde374 100644 --- a/test/jit/test_types.py +++ b/test/jit/test_types.py @@ -140,7 +140,9 @@ class TestTypesAndAnnotation(JitTestCase): wrong : List[int] = [0.5] return wrong - with self.assertRaisesRegex(RuntimeError, "Lists must contain only a single type"): + with self.assertRaisesRegex(RuntimeError, "List type annotation" + r" `List\[int\]` did not match the " + "types of the given list elements"): torch.jit.script(wrong_type) def test_optional_no_element_type_annotation(self): diff --git 
a/test/jit/test_typing.py b/test/jit/test_typing.py index f6797527556..f60f25f782e 100644 --- a/test/jit/test_typing.py +++ b/test/jit/test_typing.py @@ -2,6 +2,7 @@ import os import sys import torch +from torch.testing import FileCheck from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.common_utils import IS_WINDOWS from collections import namedtuple @@ -73,11 +74,119 @@ class TestTyping(JitTestCase): self.checkScript(test_dict_tensor_key, (dict_a, inp1)) self.checkScript(test_dict_tensor_key, (dict_a, inp2)) - def test_dict_types(self): - with self.assertRaisesRegex(RuntimeError, "single type"): - @torch.jit.script - def foo(): - new_item = {'score': [1.0], 'ys': [1, 2, 3]} + def test_list_type_refinement_defaults_to_Any_list_creation(self): + def fn(x): + tup1 = ("foo", torch.tensor(2)) + tup2 = ("bar", {"23": torch.tensor(3)}) + tup3 = ("baz", x) + l = list((tup1, tup2)) # noqa: C410 + l.append(tup3) + tup4 = l[0] + if torch.jit.isinstance(tup4, Tuple[str, torch.Tensor]): + t = tup4[1] + if isinstance(t, torch.Tensor): + l[0] = (tup4[0], torch.add(t, t)) + return l + + self.checkScript(fn, (torch.arange(5),)) + + graph = torch.jit.script(fn).graph + + print(graph) + + # Check that we're making a `List[Tuple[str, Any]]` + FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph) + + def test_list_type_refinement_defaults_to_Any_list_comprehension(self): + def fn(x): + tup1 = ("foo", torch.tensor(2)) + tup2 = ("bar", {"23": torch.tensor(3)}) + tup3 = ("baz", x) + l_ = [tup1, tup2] + l = [t for t in l_] # noqa: C416 + l.append(tup3) + tup4 = l[0] + if torch.jit.isinstance(tup4, Tuple[str, torch.Tensor]): + t = tup4[1] + if isinstance(t, torch.Tensor): + l[0] = (tup4[0], torch.add(t, t)) + return l + + self.checkScript(fn, (torch.arange(5),)) + + graph = torch.jit.script(fn).graph + + print(graph) + + # Check that we're making a `List[Tuple[str, Any]]` + FileCheck().check(r"(str, Any)[] = 
prim::ListConstruct").run(graph) + + def test_list_type_refinement_annotation_element_mismatch(self): + def fn(): + l: List[int] = [1, 2, "foo", 3] + return l + + with self.assertRaisesRegex(RuntimeError, "List type annotation" + r" `List\[int\]` did not match the " + "types of the given list elements"): + torch.jit.script(fn) + + def test_dict_type_refinement_defaults_to_Any_dict_creation(self): + def fn(x): + d = dict(foo=torch.tensor(2), + bar={"23": torch.tensor(3)}) + d["baz"] = x + t = d["foo"] + if isinstance(t, torch.Tensor): + d["bar"] = torch.add(t, t) + return d + + self.checkScript(fn, (torch.arange(5),)) + + graph = torch.jit.script(fn).graph + + FileCheck().check(r"Dict(str, Any) = prim::DictConstruct").run(graph) + + def test_dict_type_refinement_defaults_to_Any_dict_comprehension(self): + def fn(x): + d = {"foo": torch.tensor(2), + "bar": {"23": torch.tensor(3)}} + d["baz"] = x + t = d["foo"] + if isinstance(t, torch.Tensor): + d["bar"] = torch.add(t, t) + return d + + self.checkScript(fn, (torch.arange(5),)) + + graph = torch.jit.script(fn).graph + + FileCheck().check("Dict(str, Any) = prim::DictConstruct").run(graph) + + def test_dict_type_refinement_annotation_key_mismatch(self): + def fn(): + l1 = [1, 2, "foo", 3] + l2 = ["foo", "bar", "baz", "qux"] + d: Dict[int, str] = {k : v for k, v in zip(l1, l2)} + return l + + with self.assertRaisesRegex(RuntimeError, "Dict type annotation" + r" `Dict\[int, str\]` did not match" + " the type of an actual key type"): + torch.jit.script(fn) + + def test_dict_type_refinement_annotation_value_mismatch(self): + def fn(): + l1 = ["foo", "bar", "baz", "qux"] + l2 = [1, 2, "foo", 3] + d: Dict[str, int] = {k : v for k, v in zip(l1, l2)} + return l + + with self.assertRaisesRegex(RuntimeError, "Dict type annotation" + r" `Dict\[str, int\]` did not match" + " the type of an actual value " + "type"): + torch.jit.script(fn) def test_dict_invalid_annotations(self): # Check for invalid value type annotation @@ -200,16 
+309,6 @@ class TestTyping(JitTestCase): self.checkScript(fn, []) self.checkScript(fn2, (torch.ones(2, 2),)) - with self.assertRaisesRegex(RuntimeError, "Could not unify"): - @torch.jit.script - def fn(): - return [1, 1.2] - - with self.assertRaisesRegex(RuntimeError, "Could not unify"): - @torch.jit.script - def fn(): - return [1, torch.ones(1, 2)] - # to avoid defining sum_list in multiple tests def get_sum_list_fn(self): def sum_list(a): diff --git a/test/test_jit.py b/test/test_jit.py index 3bfe6cf8419..5d4096b4aab 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -10584,7 +10584,10 @@ dedent """ def f5(a): torch.cat([3]) - with self.assertRaisesRegex(RuntimeError, 'Lists must contain only a single type'): + with self.assertRaisesRegex(RuntimeError, r'Expected a value of' + r' type \'List\[int\]\' for argument' + r' \'size\' but instead found type ' + r'\'List\[Any\]\''): @torch.jit.script def f6(a): a.expand(size=[3, [4]]) diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 0133a48edb0..f54a3601bf1 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -1321,13 +1321,51 @@ struct to_ir { pushFrame(comprehension_block); WithInsertPoint guard(comprehension_block); auto emit_body = [&]() { - auto comprehension_out = emitExpr(lc.elt()); + Value* out = emitExpr(lc.elt()); + + // If we didn't have a type annotation, the type of the list would + // be set to `Tensor`. We don't want to unify this default type + // with the actual elements in the list, so let the type begin as + // the first element in the list if (!type_set) { - list_value->setType(ListType::create(comprehension_out->type())); + list_value->setType(ListType::create(out->type())); type_set = true; } + + ListTypePtr lt = list_value->type()->expect(); + + const TypePtr element_type_hint = + type_hint ? 
type_hint->expect()->getElementType() : nullptr; + + auto unified = unifyTypes( + lt->getElementType(), + out->type(), + /*default_to_any=*/true, + element_type_hint); + + if (lt->getElementType() != AnyType::get() && + *unified == AnyType::get()) { + TORCH_WARN( + "List consists of heterogeneous types, which means", + " that it has been typed as `List[Any]`. To use " + "any of the values in the List, it will be " + "necessary to add an `assert isinstance` statement " + "before first use to trigger type refinement. The first ", + "non-matching element was typed as ", + out->type()->repr_str(), + ", while the elements before it " + "were ", + lt->getElementType()->repr_str(), + "\n", + lc.range().str()); + } + + if (!type_hint) { + list_value->setType(ListType::create(*unified)); + } + NamedValue self = NamedValue(loc, "self", list_value); - NamedValue input = NamedValue(loc, "", comprehension_out); + NamedValue input = NamedValue(loc, "", out); emitBuiltinCall(loc, *graph, aten::append, {input}, {}, self); }; emitFor(targets_list, itrs, loc, emit_body); @@ -1366,10 +1404,88 @@ struct to_ir { auto emit_body = [&]() { auto k = emitExpr(dc.key()); auto v = emitExpr(dc.value()); + + // Make sure that any key and value types are subtypes of the + // annotated key/value types + if (type_hint) { + DictTypePtr dict_type_hint = type_hint->expect(); + + std::stringstream ss; + std::stringstream err; + + bool is_key_subtype = + k->type()->isSubtypeOfExt(dict_type_hint->getKeyType(), &ss); + + if (!is_key_subtype) { + err << "Dict type annotation `" << dict_type_hint->repr_str() + << "` did not match the " + << "type of an actual key type `" << k->type()->repr_str() + << "`\n" + << ss.str(); + } + + ss.str(std::string()); + bool is_value_subtype = + v->type()->isSubtypeOfExt(dict_type_hint->getValueType(), &ss); + + if (!is_value_subtype) { + err << "Dict type annotation `" << dict_type_hint->repr_str() + << "` did not match the " + << "type of an actual value type `" << 
v->type()->repr_str() + << "`\n" + << ss.str(); + } + + if (!is_key_subtype || !is_value_subtype) { + throw ErrorReport(dc) << err.str(); + } + } + + // If we didn't have a type annotation, the type of the dict would + // be set to `(str, Tensor)`. We don't want to unify this default + // type with the actual elements in the dict, so let the type + // begin as the first element in the dict if (!type_set) { dict_value->setType(DictType::create(k->type(), v->type())); type_set = true; } + + DictTypePtr dt = dict_value->type()->expect(); + + const TypePtr value_type_hint = + type_hint ? type_hint->expect()->getValueType() : nullptr; + + c10::optional unified = unifyTypes( + dt->getValueType(), + v->type(), + /*default_to_any=*/true, + value_type_hint); + + // Warn the user if we inferred the type of the values to be `Any` + // even though the annotation was something else + if (dt->getValueType() != AnyType::get() && *unified == AnyType::get()) { + TORCH_WARN( + "Dict consists of heterogeneous types, which means", + " that it has been typed as `Dict[str, Any]`. To use " + "any of the values in the Dict, it will be " + "necessary to add an `assert isinstance` statement " + "before first use to trigger type refinement. 
The first ", + "non-matching element was typed as ", + v->type()->repr_str(), + ", while the elements before it " + "were ", + dt->getValueType()->repr_str(), + "\n", + dc.range().str()); + } + + // We only want to set `dict_value` if we don't have a type hint + // to allow for the case that `*unified` is a subtype of + // the value type given by `type_hint` + if (!type_hint) { + dict_value->setType(DictType::create(k->type(), *unified)); + } + NamedValue self = NamedValue(loc, "self", dict_value); NamedValue input_k = NamedValue(loc, "", k); NamedValue input_v = NamedValue(loc, "", v); @@ -3534,6 +3650,66 @@ struct to_ir { ->call(tree->range(), method, named_values, {}, 0)); } + Value* emitListLiteral(ListLiteral ll, TypePtr type_hint) { + auto values = getValues(ll.inputs(), /*maybe_unpack=*/true); + + // Determine the element type of the list. If we have a type hint + // of `List[T]`, use `T`. If the list is non-empty, find the + // greatest common supertype of all the list elements (defaulting to + // `Any` as a catch-all supertype). Assume `[]` is `List[Tensor]` + TypePtr elem_type = TensorType::get(); + + if (type_hint) { + if (type_hint->kind() == TypeKind::ListType) { + elem_type = type_hint->expectRef().getElementType(); + } else { + // If the type hint was not `List[T]`, throw an error + throw ErrorReport(ll) << "Expected a List type hint but instead got " + << type_hint->repr_str(); + } + } + + if (!values.empty()) { + auto types = fmap(values, [](const Value* v) { return v->type(); }); + + std::stringstream nowhere; // never used + + const TypePtr element_type_hint = + type_hint ? type_hint->expect()->getElementType() : nullptr; + + c10::optional unified = unifyTypeList( + types, nowhere, /*default_to_any=*/true, element_type_hint); + + if (!type_hint && *unified == AnyType::get()) { + TORCH_WARN( + "List consists of heterogeneous types, which means", + " that it has been typed as `List[Any]`. 
To use " + "any of the values in the List, it will be " + "necessary to add an `assert isinstance` statement " + "before first use to trigger type refinement. \n", + ll.range().str()); + } + + if (type_hint && !(*unified)->isSubtypeOf(elem_type)) { + throw ErrorReport(ll) + << "List type annotation `" << type_hint->repr_str() + << "` did not match the types of the given list elements," + << " which were unified to " << (*unified)->repr_str(); + } + + // We only want to set `elem_type` if we don't have a type hint + // to allow for the case that `*unified` is a subtype of + // `type_hint` + if (!type_hint) { + elem_type = *unified; + } + } + + Value* result = + graph->insertNode(graph->createList(elem_type, values))->output(); + return result; + } + Value* emitSimpleExpr( const TreeRef& tree, const TypePtr& type_hint = nullptr) { @@ -3616,46 +3792,7 @@ struct to_ir { } break; case TK_LIST_LITERAL: { auto ll = ListLiteral(tree); - auto values = getValues(ll.inputs(), /*maybe_unpack=*/true); - - // determine the element type of the list - // if we have a type hint of List[T], use T - // if the list is non-empty use type_of(list[0]) - // otherwise assume it is List[Tensor] - TypePtr elem_type = TensorType::get(); - if (type_hint) { - if (type_hint->kind() == TypeKind::ListType) { - elem_type = type_hint->expectRef().getElementType(); - } else { - // If the type hint was not a List[T] throw an error - throw ErrorReport(tree) - << "Expected a List type hint but instead got " - << type_hint->repr_str(); - } - } else if (!values.empty()) { - std::stringstream ss; - auto types = fmap(values, [](const Value* v) { return v->type(); }); - auto maybe_elem_type = unifyTypeList(types, ss); - if (!maybe_elem_type) { - throw ErrorReport(tree) << "Lists must contain only a single type\n" - << ss.str(); - } - elem_type = maybe_elem_type.value(); - } - - for (auto v : values) { - std::stringstream ss; - if (!v->type()->isSubtypeOfExt(elem_type, &ss)) { - throw ErrorReport(tree) - << 
"Lists must contain only a single type, expected: " - << elem_type->repr_str() << " but found " - << v->type()->repr_str() << " instead.\n" - << ss.str(); - } - } - Value* result = - graph->insertNode(graph->createList(elem_type, values))->output(); - return result; + return emitListLiteral(ll, type_hint); } break; case TK_TUPLE_LITERAL: { auto ll = TupleLiteral(tree); @@ -3690,24 +3827,68 @@ struct to_ir { } AT_ASSERT(key_type != nullptr && value_type != nullptr); - auto checkTypeOfValues = [](const TypePtr& type, - const char* what, - const std::vector& values, - TreeList trees) { - for (size_t i = 0, N = values.size(); i < N; ++i) { - std::stringstream ss; - if (!values[i]->type()->isSubtypeOfExt(type, &ss)) { - throw ErrorReport(trees[i]) - << "Dict " << what - << " must contain only a single type, expected: " - << type->repr_str() << " but found " - << values[i]->type()->repr_str() << " instead.\n" - << ss.str(); + for (size_t i = 0; i < keys.size(); ++i) { + std::stringstream ss; + if (!keys[i]->type()->isSubtypeOfExt(key_type, &ss)) { + throw ErrorReport(key_trees[i]) + << "Dict keys must contain " + << "only a single type. Expected: " << key_type->repr_str() + << " but found " << keys[i]->type()->repr_str() << " instead.\n" + << ss.str(); + } + } + + if (!values.empty()) { + auto types = fmap(values, [](const Value* v) { return v->type(); }); + + std::stringstream nowhere; // never used + + const TypePtr value_type_hint = + type_hint ? type_hint->expect()->getValueType() : nullptr; + + c10::optional unified = unifyTypeList( + types, + /*why_not=*/nowhere, + /*default_to_any=*/true, + value_type_hint); + + if (!type_hint && *unified == AnyType::get()) { + TORCH_WARN( + "Dict values consist of heterogeneous types, which " + "means that they have been typed as `Any`. To use " + "any of the values in the Dict, it will be " + "necessary to add an `assert isinstance` statement " + "before first use to trigger type refinement. 
\n", + dl.range().str()); + } + + if (type_hint) { + TypePtr value_type_hint = + type_hint->expect()->getValueType(); + for (size_t i = 0; i < types.size(); ++i) { + TORCH_CHECK( + types[i]->isSubtypeOf(value_type_hint), + "Type " + "hint for dict was ", + type_hint->repr_str(), + " but the value ", + "at index ", + i, + " has type ", + types[i]->repr_str(), + ", which is not a valid" + " subtype of ", + value_type_hint->repr_str()); } } - }; - checkTypeOfValues(key_type, "keys", keys, key_trees); - checkTypeOfValues(value_type, "values", values, value_trees); + + // We only want to set `value_type` if we don't have a type + // hint to allow for the case that `*unified` is a subtype of + // the value type given by `type_hint` + if (!type_hint) { + value_type = *unified; + } + } return graph ->insertNode(graph->createDict(key_type, value_type, keys, values))