fix schema matching of tuples to vartype lists (#25944)

Summary:
In schema matching we allow a homogenous tuple to be matched to list arguments. This logic wasn't yet extended for vartype lists, causing stuff like `len((1, 2, 3))` to fail.

Fix for https://github.com/pytorch/pytorch/issues/20500
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25944

Differential Revision: D17431514

Pulled By: eellison

fbshipit-source-id: 2ad98bab15eaa496471df651572735eb35183323
This commit is contained in:
Elias Ellison 2019-09-17 13:46:20 -07:00 committed by Facebook Github Bot
parent 9181b9c73e
commit a8073f34af
7 changed files with 150 additions and 47 deletions

View file

@ -1294,7 +1294,7 @@ struct MatchTypeReturn {
}
private:
MatchTypeReturn()
MatchTypeReturn()
: reason_(c10::nullopt) {}
c10::optional<std::string> reason_; // is there is no match, this contains the reason
};
@ -1304,13 +1304,13 @@ struct MatchTypeReturn {
// and a r.reason() that describes why it could not match.
// note: It is possible to successfully match a formal, but for type variables
// in the formal to still not be defined. In particular, None matches Optional[T]
// but does not define the value of T.
// but does not define the value of T.
CAFFE2_API MatchTypeReturn
matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env);
// replace type variables appearing in `type` with the values in
// `type_env`. Returns nullptr if a variable used in `type`
// does not appear in `type_env`
// replace type variables appearing in `type` with the values in
// `type_env`. Returns nullptr if a variable used in `type`
// does not appear in `type_env`
CAFFE2_API TypePtr tryEvalTypeVariables(TypePtr type, TypeEnv& type_env);
/**

View file

@ -319,51 +319,70 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
return c10::nullopt;
}
MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env) {
if(!formal->hasFreeVariables()) {
c10::optional<TypePtr> unifyTypeList(at::ArrayRef<TypePtr> elements) {
if (elements.size() == 0) {
return c10::nullopt;
}
c10::optional<TypePtr> ret_type = elements[0];
for (size_t i = 1; i < elements.size() && ret_type; ++i) {
ret_type = unifyTypes(*ret_type, elements[i]);
}
return ret_type;
}
MatchTypeReturn matchTypeVariables(
TypePtr formal,
TypePtr actual,
TypeEnv& type_env) {
if (!formal->hasFreeVariables()) {
return MatchTypeReturn::Success();
}
if(auto vt = formal->cast<VarType>()) {
if (auto vt = formal->cast<VarType>()) {
auto it = type_env.find(vt->name());
if(it == type_env.end()) {
if (it == type_env.end()) {
type_env[vt->name()] = actual;
return MatchTypeReturn::Success();
} else if(auto unified = unifyTypes(it->second, actual)) {
} else if (auto unified = unifyTypes(it->second, actual)) {
type_env[vt->name()] = *unified;
return MatchTypeReturn::Success();
}
std::stringstream ss;
ss << "Type variable '" << vt->name() << "' previously matched to type " <<
it->second->python_str() << " is matched to type " << actual->python_str();
ss << "Type variable '" << vt->name() << "' previously matched to type "
<< it->second->python_str() << " is matched to type "
<< actual->python_str();
return ss.str();
} else if(auto lt_formal = formal->cast<ListType>()) {
if(auto lt_actual = actual->cast<ListType>()) {
} else if (auto lt_formal = formal->cast<ListType>()) {
if (auto lt_actual = actual->cast<ListType>()) {
const auto innerMatch = matchTypeVariables(
lt_formal->getElementType(),
lt_actual->getElementType(),
type_env);
lt_formal->getElementType(), lt_actual->getElementType(), type_env);
if (!innerMatch.success()) {
// propagate the errMsg onward
return innerMatch;
}
return MatchTypeReturn::Success();
} else {
std::stringstream ss;
ss << "Cannot match " << lt_formal->python_str() << " to "
<< actual->python_str();
return ss.str();
} else if (auto tup_type = actual->cast<TupleType>()) {
auto maybe_tuple_unified = unifyTypeList(tup_type->elements());
if (maybe_tuple_unified) {
return matchTypeVariables(
lt_formal->getElementType(), *maybe_tuple_unified, type_env);
}
}
} else if(auto tp_formal = formal->cast<TupleType>()) {
if(auto tp_actual = actual->cast<TupleType>()) {
if(tp_formal->elements().size() != tp_actual->elements().size()) {
std::stringstream ss;
ss << "Cannot match " << lt_formal->python_str() << " to "
<< actual->python_str();
return ss.str();
} else if (auto tp_formal = formal->cast<TupleType>()) {
if (auto tp_actual = actual->cast<TupleType>()) {
if (tp_formal->elements().size() != tp_actual->elements().size()) {
return MatchTypeReturn("Cannot match tuples of mismatched size");
}
for(size_t i = 0; i < tp_formal->elements().size(); ++i) {
for (size_t i = 0; i < tp_formal->elements().size(); ++i) {
const auto result = matchTypeVariables(
tp_formal->elements()[i],
tp_actual->elements()[i],
type_env);
tp_formal->elements()[i], tp_actual->elements()[i], type_env);
if (!result.success()) {
return result;
}
@ -401,26 +420,20 @@ MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type
// unknown type).
return matchTypeVariables(opt_formal->getElementType(), actual, type_env);
}
// note: if actual was non here we potentially did not fill in the type variables
// contained in the formal. It is still a valid match because None matches Optional[T]
// later error checking on tryEvalTypeVariables will report the problem if we never match
// variables in type T
// note: if actual was non here we potentially did not fill in the type
// variables contained in the formal. It is still a valid match because None
// matches Optional[T] later error checking on tryEvalTypeVariables will
// report the problem if we never match variables in type T
return MatchTypeReturn::Success();
} else if (auto dict_formal = formal->cast<DictType>()) {
if (auto dict_actual = actual->cast<DictType>()) {
auto key_match = matchTypeVariables(
dict_formal->getKeyType(),
dict_actual->getKeyType(),
type_env
);
dict_formal->getKeyType(), dict_actual->getKeyType(), type_env);
if (!key_match.success()) {
return key_match;
}
auto value_match = matchTypeVariables(
dict_formal->getValueType(),
dict_actual->getValueType(),
type_env
);
dict_formal->getValueType(), dict_actual->getValueType(), type_env);
if (!value_match.success()) {
return value_match;
}

View file

@ -0,0 +1,79 @@
#include <ATen/test/test_assert.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/jit.h>
#include "test/cpp/jit/test_base.h"
#include "torch/csrc/jit/custom_operator.h"
#include <sstream>
#include <string>
namespace torch {
namespace jit {
void testSchemaMatching() {
{
RegisterOperators reg({
Operator(
"aten::test_vartype(t[] a, t b) -> (t)",
[](const Node* node) {
return [](Stack& stack) {
c10::List<double> list;
double a;
pop(stack, list, a);
push(stack, a);
return 0;
};
}),
});
script::Module m("m");
m.define(R"(
def test(self):
a = (1.0, 2.0)
return torch.test_vartype(a, 2.0)
)");
auto result = m.run_method("test");
TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0);
const std::string error_example = R"JIT(
def test_2(self):
a = (1.0, 2.0)
non_float = (1, 1)
return torch.test_vartype(a, non_float)
)JIT";
ASSERT_THROWSM(m.define(error_example), "previously matched to type");
}
{
RegisterOperators reg({
Operator(
"aten::test_vartype2(t a, t[] b) -> (t[])",
[](const Node* node) {
return [](Stack& stack) {
double a;
c10::List<double> list;
pop(stack, a, list);
push(stack, a);
return 0;
};
}),
});
script::Module m("m");
m.define(R"JIT(
def test(self):
a = (1.0, 2.0)
return torch.test_vartype2(3.0, a)
)JIT");
auto result = m.run_method("test");
TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0);
static const auto error_exam2 = R"JIT(
def test_2(self):
a = (1, 2)
return torch.test_vartype2(3.0, a)
)JIT";
ASSERT_THROWSM(m.define(error_exam2), "previously matched to type");
}
}
} // namespace jit
} // namespace torch

View file

@ -19,6 +19,7 @@ namespace jit {
_(CustomOperatorAliasing) \
_(IValueKWargs) \
_(CustomFusion) \
_(SchemaMatching) \
_(Differentiate) \
_(DifferentiateWithRequiresGrad) \
_(FromQualString) \

View file

@ -6404,7 +6404,7 @@ a")
return ten1
''')
lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]",
lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]", "(1, 1)",
"torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"]
dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half",
@ -10307,6 +10307,11 @@ a")
self.assertEqual(r.dtype, torch.float)
self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)
def fn():
return torch.zeros((1, 2, 3))
self.checkScript(fn, ())
def test_vararg_zeros(self):
def foo():
return torch.zeros(3, 4, 5, dtype=torch.int)

View file

@ -226,8 +226,8 @@ bool Operator::matches(const Node* node) const {
TypeEnv type_env;
for (size_t i = 0; i < formals.size(); ++i) {
auto formal = formals[i].type();
const MatchTypeReturn matched_type =
matchTypeVariables(formal, actuals[i]->type(), type_env);
const MatchTypeReturn matched_type = matchTypeVariables(
formal, actuals[i]->type(), type_env);
if (!matched_type.success()) {
return false;
}

View file

@ -130,8 +130,8 @@ static Value* tryMatchArgument(
}
// Resolve VarType variables
const MatchTypeReturn matched =
matchTypeVariables(arg.type(), value->type(), type_env);
const MatchTypeReturn matched = matchTypeVariables(
arg.type(), value->type(), type_env);
if (!matched.success()) {
if (failure_messages) {
err() << "Could not match type " << value->type()->python_str() << " to "
@ -242,11 +242,16 @@ static bool varargsCanBeUsedAsList(
// The formal must be a list
bool argument_is_list = arg.type()->kind() == TypeKind::ListType;
// matching varargs of typevar list nyi
bool typevar_list = argument_is_list &&
arg.type()->cast<ListType>()->getElementType()->cast<VarType>();
// it must not be a broadcasting list like int[3],
// otherwise a single int is a valid input
bool arg_is_broadcasting_list = bool(arg.N());
return is_last_argument && argument_is_list & !arg_is_broadcasting_list;
return is_last_argument && argument_is_list & !arg_is_broadcasting_list &&
!typevar_list;
}
c10::optional<MatchedSchema> tryMatchSchema(