From f5c10fdbd3f1a6b2ec458dc4411dc3b3c69f2350 Mon Sep 17 00:00:00 2001 From: Ansley Ussery Date: Sat, 10 Jul 2021 14:27:49 -0700 Subject: [PATCH] Allow for heterogeneous List and Dict values + Improve container typing algorithm (#57137) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57137 This PR corrects and expands our typing algorithm for unannotated, non-empty dicts and lists. Previously, to verify type correctness for an unannotated, non-empty container, we had gotten the type of the first element in the container, then checked if each following element was a subtype of the first type. That's too restrictive--what if the first element were a subtype of the second element? Instead, we should type the container by getting the smallest common supertype of all the given elements. We need slightly different rules for keys and values in dicts, though: because the set of key types is restricted, finding two key types that cannot be unified should cause an error. On the other hand, the set of value types is not restricted, so we should be able to use `Any` as a valid supertype. We need to keep the set of keys restricted since the keys are used to generate and match schemas. This does not break backwards compatibility, because the default element type is the smallest supertype of all the given types. So, if someone creates an unannotated dict where the keys are all `str` and the values are all `torch.Tensor`, the dict will be inferred to `Dict[str, Tensor]` just like it was before. Empty lists are still typed as `List[torch.Tensor]`, and empty dicts are still typed as `Dict[str, Tensor]`. This PR unblocks three engineers on an FB-internal team and improves FX-TorchScript compatibility. 
Test Plan: Imported from OSS Reviewed By: gmagogsfm Differential Revision: D28231839 Pulled By: ansley fbshipit-source-id: 7297bf239749daa54895add708185c75e6ca5999 --- aten/src/ATen/core/jit_type.h | 12 +- aten/src/ATen/core/type.cpp | 23 +- test/HowToWriteTestsUsingFileCheck.md | 3 + test/jit/test_list_dict.py | 8 +- test/jit/test_types.py | 4 +- test/jit/test_typing.py | 129 +++++++++-- test/test_jit.py | 5 +- torch/csrc/jit/frontend/ir_emitter.cpp | 299 ++++++++++++++++++++----- 8 files changed, 393 insertions(+), 90 deletions(-) diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 9fafdfd896f..d733fbd2da5 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1574,15 +1574,21 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) { // then t2 will be returned (and vice versa). // Two different tensortypes will return dynamic. // Currently we chose not to support returning a NumberType for a float & int -// input because of a lack of operator support for NumberType +// input because of a lack of operator support for NumberType. +// If `type_hint` is an `InterfaceType`, then we can use that as a +// potential supertype for `ClassType`s in the list. 
Otherwise, we have +// no way to find and use some common interface type TORCH_API c10::optional unifyTypes( const TypePtr& t1, const TypePtr& t2, - bool default_to_any = false); + bool default_to_any = false, + TypePtr type_hint=nullptr); TORCH_API c10::optional unifyTypeList( at::ArrayRef elements, - std::ostream& why_not); + std::ostream& why_not, + bool default_to_any=false, + TypePtr type_hint=nullptr); namespace detail { template diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index fa467b7a005..e8cffebcfaf 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -265,7 +265,7 @@ AnyEnumTypePtr AnyEnumType::get() { return value; } -c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2) { +c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2, bool default_to_any=false, TypePtr type_hint=nullptr) { // check direct subtyping relation if (t1->isSubtypeOf(t2)) { return t2; @@ -308,7 +308,7 @@ c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2) { } std::vector elements; for (size_t i = 0; i < tuple1->elements().size(); i++) { - if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i))) { + if (auto elem = unifyTypes(tuple1->elements().at(i), tuple2->elements().at(i), default_to_any)) { elements.push_back(*elem); } else { return c10::nullopt; @@ -337,11 +337,18 @@ c10::optional unifyTypesImpl(const TypePtr& t1, const TypePtr& t2) { return t1_unshaped; } + // Check whether or not `type_hint` is a common parent. 
This case + // could occur if we had two class types that had been annotated with + // a common interface + if (type_hint && t1->isSubtypeOf(type_hint) && t2->isSubtypeOf(type_hint)) { + return type_hint; + } + return c10::nullopt; } -c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_any) { - auto unified = unifyTypesImpl(t1, t2); +c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_any, TypePtr type_hint) { + auto unified = unifyTypesImpl(t1, t2, default_to_any, type_hint); if (default_to_any && !unified) { return AnyType::get(); @@ -352,7 +359,9 @@ c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2, bool def c10::optional unifyTypeList( at::ArrayRef elements, - std::ostream& why_not) { + std::ostream& why_not, + bool default_to_any, + TypePtr type_hint) { if (elements.size() == 0) { why_not << "Cannot get unified type from empty list"; return c10::nullopt; @@ -360,7 +369,7 @@ c10::optional unifyTypeList( TypePtr ret_type = elements.at(0); for (size_t i = 1; i < elements.size() && ret_type; ++i) { - auto maybe_unified = unifyTypes(ret_type, elements.at(i)); + c10::optional maybe_unified = unifyTypes(ret_type, elements.at(i), default_to_any, type_hint); if (!maybe_unified) { why_not << "Could not unify type list since element " << i << " of type " << elements.at(i)->repr_str() @@ -368,7 +377,7 @@ c10::optional unifyTypeList( << ret_type->repr_str() << ")"; return c10::nullopt; } - ret_type = maybe_unified.value(); + ret_type = *maybe_unified; } return ret_type; diff --git a/test/HowToWriteTestsUsingFileCheck.md b/test/HowToWriteTestsUsingFileCheck.md index 429d7a06b48..0795c23002a 100644 --- a/test/HowToWriteTestsUsingFileCheck.md +++ b/test/HowToWriteTestsUsingFileCheck.md @@ -79,6 +79,9 @@ annotations from the example above one would write: * `CHECK: ` Scans the input until `PATTERN` is found. Fails if the pattern is not found. 
+* `CHECK-NEXT: ` + Scans the input on the line immediately following the previous CHECK until + `PATTERN` is found. Fails if the pattern is not found on that line. * `CHECK-NOT: ` Scans the input and fails if `PATTERN` is found on any line. The scan stops when a match for a next `CHECK` is found. diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index a7d30dae874..d8434515291 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -244,10 +244,10 @@ class TestList(JitTestCase): self.checkScript(fn, ()) def test_dict_keyword_with_mismatched_annotations(self): - # TODO: This fails during function schema matching, so the error - # message is not very informative to the user. Change logic so - # that the error is thrown at a different time? - err_msg = "Arguments for call are not valid" + err_msg = r"Dict type annotation `Dict\[int, str\]` did not "\ + "match the types of the actual dict items" + err_msg = r"Dict type annotation `Dict\[int, str\]` did not "\ + "match the type of an actual key type `str`" highlight_msg = "dict([(\"foo\", 1), (\"bar\", 2), (\"baz\", 3" with self.assertRaisesRegexWithHighlight(RuntimeError, err_msg, highlight_msg): @torch.jit.script diff --git a/test/jit/test_types.py b/test/jit/test_types.py index e7edc4734b4..5da4efde374 100644 --- a/test/jit/test_types.py +++ b/test/jit/test_types.py @@ -140,7 +140,9 @@ class TestTypesAndAnnotation(JitTestCase): wrong : List[int] = [0.5] return wrong - with self.assertRaisesRegex(RuntimeError, "Lists must contain only a single type"): + with self.assertRaisesRegex(RuntimeError, "List type annotation" + r" `List\[int\]` did not match the " + "types of the given list elements"): torch.jit.script(wrong_type) def test_optional_no_element_type_annotation(self): diff --git a/test/jit/test_typing.py b/test/jit/test_typing.py index f6797527556..f60f25f782e 100644 --- a/test/jit/test_typing.py +++ b/test/jit/test_typing.py @@ -2,6 +2,7 @@ import os import sys import torch 
+from torch.testing import FileCheck from torch.testing._internal.jit_utils import JitTestCase from torch.testing._internal.common_utils import IS_WINDOWS from collections import namedtuple @@ -73,11 +74,119 @@ class TestTyping(JitTestCase): self.checkScript(test_dict_tensor_key, (dict_a, inp1)) self.checkScript(test_dict_tensor_key, (dict_a, inp2)) - def test_dict_types(self): - with self.assertRaisesRegex(RuntimeError, "single type"): - @torch.jit.script - def foo(): - new_item = {'score': [1.0], 'ys': [1, 2, 3]} + def test_list_type_refinement_defaults_to_Any_list_creation(self): + def fn(x): + tup1 = ("foo", torch.tensor(2)) + tup2 = ("bar", {"23": torch.tensor(3)}) + tup3 = ("baz", x) + l = list((tup1, tup2)) # noqa: C410 + l.append(tup3) + tup4 = l[0] + if torch.jit.isinstance(tup4, Tuple[str, torch.Tensor]): + t = tup4[1] + if isinstance(t, torch.Tensor): + l[0] = (tup4[0], torch.add(t, t)) + return l + + self.checkScript(fn, (torch.arange(5),)) + + graph = torch.jit.script(fn).graph + + print(graph) + + # Check that we're making a `List[Tuple[str, Any]]` + FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph) + + def test_list_type_refinement_defaults_to_Any_list_comprehension(self): + def fn(x): + tup1 = ("foo", torch.tensor(2)) + tup2 = ("bar", {"23": torch.tensor(3)}) + tup3 = ("baz", x) + l_ = [tup1, tup2] + l = [t for t in l_] # noqa: C416 + l.append(tup3) + tup4 = l[0] + if torch.jit.isinstance(tup4, Tuple[str, torch.Tensor]): + t = tup4[1] + if isinstance(t, torch.Tensor): + l[0] = (tup4[0], torch.add(t, t)) + return l + + self.checkScript(fn, (torch.arange(5),)) + + graph = torch.jit.script(fn).graph + + print(graph) + + # Check that we're making a `List[Tuple[str, Any]]` + FileCheck().check(r"(str, Any)[] = prim::ListConstruct").run(graph) + + def test_list_type_refinement_annotation_element_mismatch(self): + def fn(): + l: List[int] = [1, 2, "foo", 3] + return l + + with self.assertRaisesRegex(RuntimeError, "List type annotation" + 
r" `List\[int\]` did not match the " + "types of the given list elements"): + torch.jit.script(fn) + + def test_dict_type_refinement_defaults_to_Any_dict_creation(self): + def fn(x): + d = dict(foo=torch.tensor(2), + bar={"23": torch.tensor(3)}) + d["baz"] = x + t = d["foo"] + if isinstance(t, torch.Tensor): + d["bar"] = torch.add(t, t) + return d + + self.checkScript(fn, (torch.arange(5),)) + + graph = torch.jit.script(fn).graph + + FileCheck().check(r"Dict(str, Any) = prim::DictConstruct").run(graph) + + def test_dict_type_refinement_defaults_to_Any_dict_comprehension(self): + def fn(x): + d = {"foo": torch.tensor(2), + "bar": {"23": torch.tensor(3)}} + d["baz"] = x + t = d["foo"] + if isinstance(t, torch.Tensor): + d["bar"] = torch.add(t, t) + return d + + self.checkScript(fn, (torch.arange(5),)) + + graph = torch.jit.script(fn).graph + + FileCheck().check("Dict(str, Any) = prim::DictConstruct").run(graph) + + def test_dict_type_refinement_annotation_key_mismatch(self): + def fn(): + l1 = [1, 2, "foo", 3] + l2 = ["foo", "bar", "baz", "qux"] + d: Dict[int, str] = {k : v for k, v in zip(l1, l2)} + return l + + with self.assertRaisesRegex(RuntimeError, "Dict type annotation" + r" `Dict\[int, str\]` did not match" + " the type of an actual key type"): + torch.jit.script(fn) + + def test_dict_type_refinement_annotation_value_mismatch(self): + def fn(): + l1 = ["foo", "bar", "baz", "qux"] + l2 = [1, 2, "foo", 3] + d: Dict[str, int] = {k : v for k, v in zip(l1, l2)} + return l + + with self.assertRaisesRegex(RuntimeError, "Dict type annotation" + r" `Dict\[str, int\]` did not match" + " the type of an actual value " + "type"): + torch.jit.script(fn) def test_dict_invalid_annotations(self): # Check for invalid value type annotation @@ -200,16 +309,6 @@ class TestTyping(JitTestCase): self.checkScript(fn, []) self.checkScript(fn2, (torch.ones(2, 2),)) - with self.assertRaisesRegex(RuntimeError, "Could not unify"): - @torch.jit.script - def fn(): - return [1, 1.2] - - with 
self.assertRaisesRegex(RuntimeError, "Could not unify"): - @torch.jit.script - def fn(): - return [1, torch.ones(1, 2)] - # to avoid defining sum_list in multiple tests def get_sum_list_fn(self): def sum_list(a): diff --git a/test/test_jit.py b/test/test_jit.py index 3bfe6cf8419..5d4096b4aab 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -10584,7 +10584,10 @@ dedent """ def f5(a): torch.cat([3]) - with self.assertRaisesRegex(RuntimeError, 'Lists must contain only a single type'): + with self.assertRaisesRegex(RuntimeError, r'Expected a value of' + r' type \'List\[int\]\' for argument' + r' \'size\' but instead found type ' + r'\'List\[Any\]\''): @torch.jit.script def f6(a): a.expand(size=[3, [4]]) diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 0133a48edb0..f54a3601bf1 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -1321,13 +1321,51 @@ struct to_ir { pushFrame(comprehension_block); WithInsertPoint guard(comprehension_block); auto emit_body = [&]() { - auto comprehension_out = emitExpr(lc.elt()); + Value* out = emitExpr(lc.elt()); + + // If we didn't have a type annotation, the type of the list would + // be set to `Tensor`. We don't want to unify this default type + // with the actual elements in the list, so let the type begin as + // the first element in the list if (!type_set) { - list_value->setType(ListType::create(comprehension_out->type())); + list_value->setType(ListType::create(out->type())); type_set = true; } + + ListTypePtr lt = list_value->type()->expect(); + + const TypePtr element_type_hint = + type_hint ? 
type_hint->expect()->getElementType() : nullptr; + + auto unified = unifyTypes( + lt->getElementType(), + out->type(), + /*default_to_any=*/true, + element_type_hint); + + if (lt->getElementType() != AnyType::get() && + *unified == AnyType::get()) { + TORCH_WARN( + "List consists of heterogeneous types, which means", + " that it has been typed as `List[Any]`. To use " + "any of the values in the List, it will be " + "necessary to add an `assert isinstance` statement " + "before first use to trigger type refinement. The first ", + "non-matching element was typed as ", + out->type()->repr_str(), + ", while the elements before it " + "were ", + lt->getElementType()->repr_str(), + "\n", + lc.range().str()); + } + + if (!type_hint) { + list_value->setType(ListType::create(*unified)); + } + NamedValue self = NamedValue(loc, "self", list_value); - NamedValue input = NamedValue(loc, "", comprehension_out); + NamedValue input = NamedValue(loc, "", out); emitBuiltinCall(loc, *graph, aten::append, {input}, {}, self); }; emitFor(targets_list, itrs, loc, emit_body); @@ -1366,10 +1404,88 @@ struct to_ir { auto emit_body = [&]() { auto k = emitExpr(dc.key()); auto v = emitExpr(dc.value()); + + // Make sure that any key and value types are subtypes of the + // annotatated key/value types + if (type_hint) { + DictTypePtr dict_type_hint = type_hint->expect(); + + std::stringstream ss; + std::stringstream err; + + bool is_key_subtype = + k->type()->isSubtypeOfExt(dict_type_hint->getKeyType(), &ss); + + if (!is_key_subtype) { + err << "Dict type annotation `" << dict_type_hint->repr_str() + << "` did not match the " + << "type of an actual key type `" << k->type()->repr_str() + << "`\n" + << ss.str(); + } + + ss.str(std::string()); + bool is_value_subtype = + v->type()->isSubtypeOfExt(dict_type_hint->getValueType(), &ss); + + if (!is_value_subtype) { + err << "Dict type annotation `" << dict_type_hint->repr_str() + << "` did not match the " + << "type of an actual value type `" << 
v->type()->repr_str() + << "`\n" + << ss.str(); + } + + if (!is_key_subtype || !is_value_subtype) { + throw ErrorReport(dc) << err.str(); + } + } + + // If we didn't have a type annotation, the type of the dict would + // be set to `(str, Tensor)`. We don't want to unify this default + // type with the actual elements in the dict, so let the type + // begin as the first element in the dict if (!type_set) { dict_value->setType(DictType::create(k->type(), v->type())); type_set = true; } + + DictTypePtr dt = dict_value->type()->expect(); + + const TypePtr value_type_hint = + type_hint ? type_hint->expect()->getKeyType() : nullptr; + + c10::optional unified = unifyTypes( + dt->getValueType(), + v->type(), + /*default_to_any=*/true, + value_type_hint); + + // Warn the user if we inferred the type of the values to be `Any` + // even though the annotation was something else + if (dt->getValueType() != AnyType::get() && *unified == AnyType::get()) { + TORCH_WARN( + "Dict consists of heterogeneous types, which means", + " that it has been typed as `Dict[str, Any]`. To use " + "any of the values in the Dict, it will be " + "necessary to add an `assert isinstance` statement " + "before first use to trigger type refinement. 
The first ", + "non-matching element was typed as ", + v->type()->repr_str(), + ", while the elements before it " + "were ", + dt->getValueType()->repr_str(), + "\n", + dc.range().str()); + } + + // We only want to set `dict_value` if we don't have a type hint + // to allow for the case that `*unified` is a subtype of + // the value type given by `type_hint` + if (!type_hint) { + dict_value->setType(DictType::create(k->type(), *unified)); + } + NamedValue self = NamedValue(loc, "self", dict_value); NamedValue input_k = NamedValue(loc, "", k); NamedValue input_v = NamedValue(loc, "", v); @@ -3534,6 +3650,66 @@ struct to_ir { ->call(tree->range(), method, named_values, {}, 0)); } + Value* emitListLiteral(ListLiteral ll, TypePtr type_hint) { + auto values = getValues(ll.inputs(), /*maybe_unpack=*/true); + + // Determine the element type of the list. If we have a type hint + // of `List[T]`, use `T`. If the list is non-empty, find the + // greatest common supertype of all the list elements (defaulting to + // `Any` as a catch-all supertype). Assume `[]` is `List[Tensor]` + TypePtr elem_type = TensorType::get(); + + if (type_hint) { + if (type_hint->kind() == TypeKind::ListType) { + elem_type = type_hint->expectRef().getElementType(); + } else { + // If the type hint was not `List[T]`, throw an error + throw ErrorReport(ll) << "Expected a List type hint but instead got " + << type_hint->repr_str(); + } + } + + if (!values.empty()) { + auto types = fmap(values, [](const Value* v) { return v->type(); }); + + std::stringstream nowhere; // never used + + const TypePtr element_type_hint = + type_hint ? type_hint->expect()->getElementType() : nullptr; + + c10::optional unified = unifyTypeList( + types, nowhere, /*default_to_any=*/true, element_type_hint); + + if (!type_hint && *unified == AnyType::get()) { + TORCH_WARN( + "List consists of heterogeneous types, which means", + " that it has been typed as `List[Any]`. 
To use " + "any of the values in the List, it will be " + "necessary to add an `assert isinstance` statement " + "before first use to trigger type refinement. \n", + ll.range().str()); + } + + if (type_hint && !(*unified)->isSubtypeOf(elem_type)) { + throw ErrorReport(ll) + << "List type annotation `" << type_hint->repr_str() + << "` did not match the types of the given list elements," + << " which were unified to " << (*unified)->repr_str(); + } + + // We only want to set `elem_type` if we don't have a type hint + // to allow for the case that `*unified` is a subtype of + // `type_hint` + if (!type_hint) { + elem_type = *unified; + } + } + + Value* result = + graph->insertNode(graph->createList(elem_type, values))->output(); + return result; + } + Value* emitSimpleExpr( const TreeRef& tree, const TypePtr& type_hint = nullptr) { @@ -3616,46 +3792,7 @@ struct to_ir { } break; case TK_LIST_LITERAL: { auto ll = ListLiteral(tree); - auto values = getValues(ll.inputs(), /*maybe_unpack=*/true); - - // determine the element type of the list - // if we have a type hint of List[T], use T - // if the list is non-empty use type_of(list[0]) - // otherwise assume it is List[Tensor] - TypePtr elem_type = TensorType::get(); - if (type_hint) { - if (type_hint->kind() == TypeKind::ListType) { - elem_type = type_hint->expectRef().getElementType(); - } else { - // If the type hint was not a List[T] throw an error - throw ErrorReport(tree) - << "Expected a List type hint but instead got " - << type_hint->repr_str(); - } - } else if (!values.empty()) { - std::stringstream ss; - auto types = fmap(values, [](const Value* v) { return v->type(); }); - auto maybe_elem_type = unifyTypeList(types, ss); - if (!maybe_elem_type) { - throw ErrorReport(tree) << "Lists must contain only a single type\n" - << ss.str(); - } - elem_type = maybe_elem_type.value(); - } - - for (auto v : values) { - std::stringstream ss; - if (!v->type()->isSubtypeOfExt(elem_type, &ss)) { - throw ErrorReport(tree) - << 
"Lists must contain only a single type, expected: " - << elem_type->repr_str() << " but found " - << v->type()->repr_str() << " instead.\n" - << ss.str(); - } - } - Value* result = - graph->insertNode(graph->createList(elem_type, values))->output(); - return result; + return emitListLiteral(ll, type_hint); } break; case TK_TUPLE_LITERAL: { auto ll = TupleLiteral(tree); @@ -3690,24 +3827,68 @@ struct to_ir { } AT_ASSERT(key_type != nullptr && value_type != nullptr); - auto checkTypeOfValues = [](const TypePtr& type, - const char* what, - const std::vector& values, - TreeList trees) { - for (size_t i = 0, N = values.size(); i < N; ++i) { - std::stringstream ss; - if (!values[i]->type()->isSubtypeOfExt(type, &ss)) { - throw ErrorReport(trees[i]) - << "Dict " << what - << " must contain only a single type, expected: " - << type->repr_str() << " but found " - << values[i]->type()->repr_str() << " instead.\n" - << ss.str(); + for (size_t i = 0; i < keys.size(); ++i) { + std::stringstream ss; + if (!keys[i]->type()->isSubtypeOfExt(key_type, &ss)) { + throw ErrorReport(key_trees[i]) + << "Dict keys must contain " + << "only a single type. Expected: " << key_type->repr_str() + << " but found " << keys[i]->type()->repr_str() << " instead.\n" + << ss.str(); + } + } + + if (!values.empty()) { + auto types = fmap(values, [](const Value* v) { return v->type(); }); + + std::stringstream nowhere; // never used + + const TypePtr value_type_hint = + type_hint ? type_hint->expect()->getKeyType() : nullptr; + + c10::optional unified = unifyTypeList( + types, + /*why_not=*/nowhere, + /*default_to_any=*/true, + value_type_hint); + + if (!type_hint && *unified == AnyType::get()) { + TORCH_WARN( + "Dict values consist of heterogeneous types, which " + "means that they have been typed as `Any`. To use " + "any of the values in the Dist, it will be " + "necessary to add an `assert isinstance` statement " + "before first use to trigger type refinement. 
\n", + dl.range().str()); + } + + if (type_hint) { + TypePtr value_type_hint = + type_hint->expect()->getValueType(); + for (size_t i = 0; i < types.size(); ++i) { + TORCH_CHECK( + types[i]->isSubtypeOf(value_type_hint), + "Type " + "hint for dict was", + type_hint->repr_str(), + "but the value ", + "at index ", + i, + " has type ", + types[i]->repr_str(), + ", which is not a valid" + " subtype of ", + value_type_hint->repr_str()); } } - }; - checkTypeOfValues(key_type, "keys", keys, key_trees); - checkTypeOfValues(value_type, "values", values, value_trees); + + // We only want to set `value_type` if we don't have a type + // hint to allow for the case that `*unified` is a subtype of + // the value type given by `type_hint` + if (!type_hint) { + value_type = *unified; + } + } return graph ->insertNode(graph->createDict(key_type, value_type, keys, values))