From a8073f34afc4e2a708af833e6d5ca773ddd9d72b Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Tue, 17 Sep 2019 13:46:20 -0700 Subject: [PATCH] fix schema matching of tuples to vartype lists (#25944) Summary: In schema matching we allow a homogenous tuple to be matched to list arguments. This logic wasn't yet extended for vartype lists, causing stuff like `len((1, 2, 3))` to fail. Fix for https://github.com/pytorch/pytorch/issues/20500 Pull Request resolved: https://github.com/pytorch/pytorch/pull/25944 Differential Revision: D17431514 Pulled By: eellison fbshipit-source-id: 2ad98bab15eaa496471df651572735eb35183323 --- aten/src/ATen/core/jit_type.h | 10 +-- aten/src/ATen/core/type.cpp | 85 +++++++++++++---------- test/cpp/jit/test_schema_matching.cpp | 79 +++++++++++++++++++++ test/cpp/jit/tests.h | 1 + test/test_jit.py | 7 +- torch/csrc/jit/operator.cpp | 4 +- torch/csrc/jit/script/schema_matching.cpp | 11 ++- 7 files changed, 150 insertions(+), 47 deletions(-) create mode 100644 test/cpp/jit/test_schema_matching.cpp diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index a94228897d5..b8f335f4139 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1294,7 +1294,7 @@ struct MatchTypeReturn { } private: - MatchTypeReturn() + MatchTypeReturn() : reason_(c10::nullopt) {} c10::optional reason_; // is there is no match, this contains the reason }; @@ -1304,13 +1304,13 @@ struct MatchTypeReturn { // and a r.reason() that describes why it could not match. // note: It is possible to successfully match a formal, but for type variables // in the formal to still not be defined. In particular, None matches Optional[T] -// but does not define the value of T. +// but does not define the value of T. CAFFE2_API MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env); -// replace type variables appearing in `type` with the values in -// `type_env`. Returns nullptr if a variable used in `type` -// does not appear in `type_env` +// replace type variables appearing in `type` with the values in +// `type_env`. Returns nullptr if a variable used in `type` +// does not appear in `type_env` CAFFE2_API TypePtr tryEvalTypeVariables(TypePtr type, TypeEnv& type_env); /** diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index a084fae1514..37160843f36 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -319,51 +319,70 @@ c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2) { return c10::nullopt; } -MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env) { - if(!formal->hasFreeVariables()) { +c10::optional unifyTypeList(at::ArrayRef elements) { + if (elements.size() == 0) { + return c10::nullopt; + } + + c10::optional ret_type = elements[0]; + for (size_t i = 1; i < elements.size() && ret_type; ++i) { + ret_type = unifyTypes(*ret_type, elements[i]); + } + + return ret_type; +} + +MatchTypeReturn matchTypeVariables( + TypePtr formal, + TypePtr actual, + TypeEnv& type_env) { + if (!formal->hasFreeVariables()) { return MatchTypeReturn::Success(); } - if(auto vt = formal->cast()) { + if (auto vt = formal->cast()) { auto it = type_env.find(vt->name()); - if(it == type_env.end()) { + if (it == type_env.end()) { type_env[vt->name()] = actual; return MatchTypeReturn::Success(); - } else if(auto unified = unifyTypes(it->second, actual)) { + } else if (auto unified = unifyTypes(it->second, actual)) { type_env[vt->name()] = *unified; return MatchTypeReturn::Success(); } std::stringstream ss; - ss << "Type variable '" << vt->name() << "' previously matched to type " << - it->second->python_str() << " is matched to type " << actual->python_str(); + ss << "Type variable '" << vt->name() << "' previously matched to type " + << it->second->python_str() << " is matched to type " + << actual->python_str(); return ss.str(); - } else if(auto lt_formal = formal->cast()) { - if(auto lt_actual = actual->cast()) { + } else if (auto lt_formal = formal->cast()) { + if (auto lt_actual = actual->cast()) { const auto innerMatch = matchTypeVariables( - lt_formal->getElementType(), - lt_actual->getElementType(), - type_env); + lt_formal->getElementType(), lt_actual->getElementType(), type_env); if (!innerMatch.success()) { // propagate the errMsg onward return innerMatch; } return MatchTypeReturn::Success(); - } else { - std::stringstream ss; - ss << "Cannot match " << lt_formal->python_str() << " to " - << actual->python_str(); - return ss.str(); + } else if (auto tup_type = actual->cast()) { + auto maybe_tuple_unified = unifyTypeList(tup_type->elements()); + if (maybe_tuple_unified) { + return matchTypeVariables( + lt_formal->getElementType(), *maybe_tuple_unified, type_env); + } } - } else if(auto tp_formal = formal->cast()) { - if(auto tp_actual = actual->cast()) { - if(tp_formal->elements().size() != tp_actual->elements().size()) { + + std::stringstream ss; + ss << "Cannot match " << lt_formal->python_str() << " to " + << actual->python_str(); + return ss.str(); + } else if (auto tp_formal = formal->cast()) { + if (auto tp_actual = actual->cast()) { + if (tp_formal->elements().size() != tp_actual->elements().size()) { return MatchTypeReturn("Cannot match tuples of mismatched size"); } - for(size_t i = 0; i < tp_formal->elements().size(); ++i) { + for (size_t i = 0; i < tp_formal->elements().size(); ++i) { const auto result = matchTypeVariables( - tp_formal->elements()[i], - tp_actual->elements()[i], - type_env); + tp_formal->elements()[i], tp_actual->elements()[i], type_env); if (!result.success()) { return result; } @@ -401,26 +420,20 @@ MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type // unknown type). return matchTypeVariables(opt_formal->getElementType(), actual, type_env); } - // note: if actual was non here we potentially did not fill in the type variables - // contained in the formal. It is still a valid match because None matches Optional[T] - // later error checking on tryEvalTypeVariables will report the problem if we never match - // variables in type T + // note: if actual was non here we potentially did not fill in the type + // variables contained in the formal. It is still a valid match because None + // matches Optional[T] later error checking on tryEvalTypeVariables will + // report the problem if we never match variables in type T return MatchTypeReturn::Success(); } else if (auto dict_formal = formal->cast()) { if (auto dict_actual = actual->cast()) { auto key_match = matchTypeVariables( - dict_formal->getKeyType(), - dict_actual->getKeyType(), - type_env - ); + dict_formal->getKeyType(), dict_actual->getKeyType(), type_env); if (!key_match.success()) { return key_match; } auto value_match = matchTypeVariables( - dict_formal->getValueType(), - dict_actual->getValueType(), - type_env - ); + dict_formal->getValueType(), dict_actual->getValueType(), type_env); if (!value_match.success()) { return value_match; } diff --git a/test/cpp/jit/test_schema_matching.cpp b/test/cpp/jit/test_schema_matching.cpp new file mode 100644 index 00000000000..442cfa10b38 --- /dev/null +++ b/test/cpp/jit/test_schema_matching.cpp @@ -0,0 +1,79 @@ +#include +#include +#include +#include +#include "test/cpp/jit/test_base.h" +#include "torch/csrc/jit/custom_operator.h" + +#include +#include + +namespace torch { +namespace jit { + +void testSchemaMatching() { + { + RegisterOperators reg({ + Operator( + "aten::test_vartype(t[] a, t b) -> (t)", + [](const Node* node) { + return [](Stack& stack) { + c10::List list; + double a; + pop(stack, list, a); + push(stack, a); + return 0; + }; + }), + }); + script::Module m("m"); + m.define(R"( + def test(self): + a = (1.0, 2.0) + return torch.test_vartype(a, 2.0) + )"); + auto result = m.run_method("test"); + TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0); + + const std::string error_example = R"JIT( + def test_2(self): + a = (1.0, 2.0) + non_float = (1, 1) + return torch.test_vartype(a, non_float) + )JIT"; + + ASSERT_THROWSM(m.define(error_example), "previously matched to type"); + } + { + RegisterOperators reg({ + Operator( + "aten::test_vartype2(t a, t[] b) -> (t[])", + [](const Node* node) { + return [](Stack& stack) { + double a; + c10::List list; + pop(stack, a, list); + push(stack, a); + return 0; + }; + }), + }); + script::Module m("m"); + m.define(R"JIT( + def test(self): + a = (1.0, 2.0) + return torch.test_vartype2(3.0, a) + )JIT"); + auto result = m.run_method("test"); + TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0); + + static const auto error_exam2 = R"JIT( + def test_2(self): + a = (1, 2) + return torch.test_vartype2(3.0, a) + )JIT"; + ASSERT_THROWSM(m.define(error_exam2), "previously matched to type"); + } +} +} // namespace jit +} // namespace torch diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index ef1cbd81c26..2ac35d99c3b 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -19,6 +19,7 @@ namespace jit { _(CustomOperatorAliasing) \ _(IValueKWargs) \ _(CustomFusion) \ + _(SchemaMatching) \ _(Differentiate) \ _(DifferentiateWithRequiresGrad) \ _(FromQualString) \ diff --git a/test/test_jit.py b/test/test_jit.py index 43c46991f93..17ddd19651a 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -6404,7 +6404,7 @@ a") return ten1 ''') - lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]", + lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]", "(1, 1)", "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"] dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half", @@ -10307,6 +10307,11 @@ a") self.assertEqual(r.dtype, torch.float) self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r) + def fn(): + return torch.zeros((1, 2, 3)) + + self.checkScript(fn, ()) + def test_vararg_zeros(self): def foo(): return torch.zeros(3, 4, 5, dtype=torch.int) diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index 3d56ae240e9..02b6a19ae77 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -226,8 +226,8 @@ bool Operator::matches(const Node* node) const { TypeEnv type_env; for (size_t i = 0; i < formals.size(); ++i) { auto formal = formals[i].type(); - const MatchTypeReturn matched_type = - matchTypeVariables(formal, actuals[i]->type(), type_env); + const MatchTypeReturn matched_type = matchTypeVariables( + formal, actuals[i]->type(), type_env); if (!matched_type.success()) { return false; } diff --git a/torch/csrc/jit/script/schema_matching.cpp b/torch/csrc/jit/script/schema_matching.cpp index 78d1adffe39..8dca641cb07 100644 --- a/torch/csrc/jit/script/schema_matching.cpp +++ b/torch/csrc/jit/script/schema_matching.cpp @@ -130,8 +130,8 @@ static Value* tryMatchArgument( } // Resolve VarType variables - const MatchTypeReturn matched = - matchTypeVariables(arg.type(), value->type(), type_env); + const MatchTypeReturn matched = matchTypeVariables( + arg.type(), value->type(), type_env); if (!matched.success()) { if (failure_messages) { err() << "Could not match type " << value->type()->python_str() << " to " @@ -242,11 +242,16 @@ static bool varargsCanBeUsedAsList( // The formal must be a list bool argument_is_list = arg.type()->kind() == TypeKind::ListType; + // matching varargs of typevar list nyi + bool typevar_list = argument_is_list && + arg.type()->cast()->getElementType()->cast(); + // it must not be a broadcasting list like int[3], // otherwise a single int is a valid input bool arg_is_broadcasting_list = bool(arg.N()); - return is_last_argument && argument_is_list & !arg_is_broadcasting_list; + return is_last_argument && argument_is_list & !arg_is_broadcasting_list && + !typevar_list; } c10::optional tryMatchSchema(