mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
fix schema matching of tuples to vartype lists (#25944)
Summary: In schema matching we allow a homogenous tuple to be matched to list arguments. This logic wasn't yet extended for vartype lists, causing stuff like `len((1, 2, 3))` to fail. Fix for https://github.com/pytorch/pytorch/issues/20500 Pull Request resolved: https://github.com/pytorch/pytorch/pull/25944 Differential Revision: D17431514 Pulled By: eellison fbshipit-source-id: 2ad98bab15eaa496471df651572735eb35183323
This commit is contained in:
parent
9181b9c73e
commit
a8073f34af
7 changed files with 150 additions and 47 deletions
|
|
@ -1294,7 +1294,7 @@ struct MatchTypeReturn {
|
|||
}
|
||||
|
||||
private:
|
||||
MatchTypeReturn()
|
||||
MatchTypeReturn()
|
||||
: reason_(c10::nullopt) {}
|
||||
c10::optional<std::string> reason_; // is there is no match, this contains the reason
|
||||
};
|
||||
|
|
@ -1304,13 +1304,13 @@ struct MatchTypeReturn {
|
|||
// and a r.reason() that describes why it could not match.
|
||||
// note: It is possible to successfully match a formal, but for type variables
|
||||
// in the formal to still not be defined. In particular, None matches Optional[T]
|
||||
// but does not define the value of T.
|
||||
// but does not define the value of T.
|
||||
CAFFE2_API MatchTypeReturn
|
||||
matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env);
|
||||
|
||||
// replace type variables appearing in `type` with the values in
|
||||
// `type_env`. Returns nullptr if a variable used in `type`
|
||||
// does not appear in `type_env`
|
||||
// replace type variables appearing in `type` with the values in
|
||||
// `type_env`. Returns nullptr if a variable used in `type`
|
||||
// does not appear in `type_env`
|
||||
CAFFE2_API TypePtr tryEvalTypeVariables(TypePtr type, TypeEnv& type_env);
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -319,51 +319,70 @@ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
|
|||
return c10::nullopt;
|
||||
}
|
||||
|
||||
MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env) {
|
||||
if(!formal->hasFreeVariables()) {
|
||||
c10::optional<TypePtr> unifyTypeList(at::ArrayRef<TypePtr> elements) {
|
||||
if (elements.size() == 0) {
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
c10::optional<TypePtr> ret_type = elements[0];
|
||||
for (size_t i = 1; i < elements.size() && ret_type; ++i) {
|
||||
ret_type = unifyTypes(*ret_type, elements[i]);
|
||||
}
|
||||
|
||||
return ret_type;
|
||||
}
|
||||
|
||||
MatchTypeReturn matchTypeVariables(
|
||||
TypePtr formal,
|
||||
TypePtr actual,
|
||||
TypeEnv& type_env) {
|
||||
if (!formal->hasFreeVariables()) {
|
||||
return MatchTypeReturn::Success();
|
||||
}
|
||||
|
||||
if(auto vt = formal->cast<VarType>()) {
|
||||
if (auto vt = formal->cast<VarType>()) {
|
||||
auto it = type_env.find(vt->name());
|
||||
if(it == type_env.end()) {
|
||||
if (it == type_env.end()) {
|
||||
type_env[vt->name()] = actual;
|
||||
return MatchTypeReturn::Success();
|
||||
} else if(auto unified = unifyTypes(it->second, actual)) {
|
||||
} else if (auto unified = unifyTypes(it->second, actual)) {
|
||||
type_env[vt->name()] = *unified;
|
||||
return MatchTypeReturn::Success();
|
||||
}
|
||||
std::stringstream ss;
|
||||
ss << "Type variable '" << vt->name() << "' previously matched to type " <<
|
||||
it->second->python_str() << " is matched to type " << actual->python_str();
|
||||
ss << "Type variable '" << vt->name() << "' previously matched to type "
|
||||
<< it->second->python_str() << " is matched to type "
|
||||
<< actual->python_str();
|
||||
return ss.str();
|
||||
} else if(auto lt_formal = formal->cast<ListType>()) {
|
||||
if(auto lt_actual = actual->cast<ListType>()) {
|
||||
} else if (auto lt_formal = formal->cast<ListType>()) {
|
||||
if (auto lt_actual = actual->cast<ListType>()) {
|
||||
const auto innerMatch = matchTypeVariables(
|
||||
lt_formal->getElementType(),
|
||||
lt_actual->getElementType(),
|
||||
type_env);
|
||||
lt_formal->getElementType(), lt_actual->getElementType(), type_env);
|
||||
if (!innerMatch.success()) {
|
||||
// propagate the errMsg onward
|
||||
return innerMatch;
|
||||
}
|
||||
return MatchTypeReturn::Success();
|
||||
} else {
|
||||
std::stringstream ss;
|
||||
ss << "Cannot match " << lt_formal->python_str() << " to "
|
||||
<< actual->python_str();
|
||||
return ss.str();
|
||||
} else if (auto tup_type = actual->cast<TupleType>()) {
|
||||
auto maybe_tuple_unified = unifyTypeList(tup_type->elements());
|
||||
if (maybe_tuple_unified) {
|
||||
return matchTypeVariables(
|
||||
lt_formal->getElementType(), *maybe_tuple_unified, type_env);
|
||||
}
|
||||
}
|
||||
} else if(auto tp_formal = formal->cast<TupleType>()) {
|
||||
if(auto tp_actual = actual->cast<TupleType>()) {
|
||||
if(tp_formal->elements().size() != tp_actual->elements().size()) {
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "Cannot match " << lt_formal->python_str() << " to "
|
||||
<< actual->python_str();
|
||||
return ss.str();
|
||||
} else if (auto tp_formal = formal->cast<TupleType>()) {
|
||||
if (auto tp_actual = actual->cast<TupleType>()) {
|
||||
if (tp_formal->elements().size() != tp_actual->elements().size()) {
|
||||
return MatchTypeReturn("Cannot match tuples of mismatched size");
|
||||
}
|
||||
for(size_t i = 0; i < tp_formal->elements().size(); ++i) {
|
||||
for (size_t i = 0; i < tp_formal->elements().size(); ++i) {
|
||||
const auto result = matchTypeVariables(
|
||||
tp_formal->elements()[i],
|
||||
tp_actual->elements()[i],
|
||||
type_env);
|
||||
tp_formal->elements()[i], tp_actual->elements()[i], type_env);
|
||||
if (!result.success()) {
|
||||
return result;
|
||||
}
|
||||
|
|
@ -401,26 +420,20 @@ MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type
|
|||
// unknown type).
|
||||
return matchTypeVariables(opt_formal->getElementType(), actual, type_env);
|
||||
}
|
||||
// note: if actual was non here we potentially did not fill in the type variables
|
||||
// contained in the formal. It is still a valid match because None matches Optional[T]
|
||||
// later error checking on tryEvalTypeVariables will report the problem if we never match
|
||||
// variables in type T
|
||||
// note: if actual was non here we potentially did not fill in the type
|
||||
// variables contained in the formal. It is still a valid match because None
|
||||
// matches Optional[T] later error checking on tryEvalTypeVariables will
|
||||
// report the problem if we never match variables in type T
|
||||
return MatchTypeReturn::Success();
|
||||
} else if (auto dict_formal = formal->cast<DictType>()) {
|
||||
if (auto dict_actual = actual->cast<DictType>()) {
|
||||
auto key_match = matchTypeVariables(
|
||||
dict_formal->getKeyType(),
|
||||
dict_actual->getKeyType(),
|
||||
type_env
|
||||
);
|
||||
dict_formal->getKeyType(), dict_actual->getKeyType(), type_env);
|
||||
if (!key_match.success()) {
|
||||
return key_match;
|
||||
}
|
||||
auto value_match = matchTypeVariables(
|
||||
dict_formal->getValueType(),
|
||||
dict_actual->getValueType(),
|
||||
type_env
|
||||
);
|
||||
dict_formal->getValueType(), dict_actual->getValueType(), type_env);
|
||||
if (!value_match.success()) {
|
||||
return value_match;
|
||||
}
|
||||
|
|
|
|||
79
test/cpp/jit/test_schema_matching.cpp
Normal file
79
test/cpp/jit/test_schema_matching.cpp
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
#include <ATen/test/test_assert.h>
|
||||
#include <torch/csrc/jit/ir.h>
|
||||
#include <torch/csrc/jit/testing/file_check.h>
|
||||
#include <torch/jit.h>
|
||||
#include "test/cpp/jit/test_base.h"
|
||||
#include "torch/csrc/jit/custom_operator.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
void testSchemaMatching() {
|
||||
{
|
||||
RegisterOperators reg({
|
||||
Operator(
|
||||
"aten::test_vartype(t[] a, t b) -> (t)",
|
||||
[](const Node* node) {
|
||||
return [](Stack& stack) {
|
||||
c10::List<double> list;
|
||||
double a;
|
||||
pop(stack, list, a);
|
||||
push(stack, a);
|
||||
return 0;
|
||||
};
|
||||
}),
|
||||
});
|
||||
script::Module m("m");
|
||||
m.define(R"(
|
||||
def test(self):
|
||||
a = (1.0, 2.0)
|
||||
return torch.test_vartype(a, 2.0)
|
||||
)");
|
||||
auto result = m.run_method("test");
|
||||
TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0);
|
||||
|
||||
const std::string error_example = R"JIT(
|
||||
def test_2(self):
|
||||
a = (1.0, 2.0)
|
||||
non_float = (1, 1)
|
||||
return torch.test_vartype(a, non_float)
|
||||
)JIT";
|
||||
|
||||
ASSERT_THROWSM(m.define(error_example), "previously matched to type");
|
||||
}
|
||||
{
|
||||
RegisterOperators reg({
|
||||
Operator(
|
||||
"aten::test_vartype2(t a, t[] b) -> (t[])",
|
||||
[](const Node* node) {
|
||||
return [](Stack& stack) {
|
||||
double a;
|
||||
c10::List<double> list;
|
||||
pop(stack, a, list);
|
||||
push(stack, a);
|
||||
return 0;
|
||||
};
|
||||
}),
|
||||
});
|
||||
script::Module m("m");
|
||||
m.define(R"JIT(
|
||||
def test(self):
|
||||
a = (1.0, 2.0)
|
||||
return torch.test_vartype2(3.0, a)
|
||||
)JIT");
|
||||
auto result = m.run_method("test");
|
||||
TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0);
|
||||
|
||||
static const auto error_exam2 = R"JIT(
|
||||
def test_2(self):
|
||||
a = (1, 2)
|
||||
return torch.test_vartype2(3.0, a)
|
||||
)JIT";
|
||||
ASSERT_THROWSM(m.define(error_exam2), "previously matched to type");
|
||||
}
|
||||
}
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -19,6 +19,7 @@ namespace jit {
|
|||
_(CustomOperatorAliasing) \
|
||||
_(IValueKWargs) \
|
||||
_(CustomFusion) \
|
||||
_(SchemaMatching) \
|
||||
_(Differentiate) \
|
||||
_(DifferentiateWithRequiresGrad) \
|
||||
_(FromQualString) \
|
||||
|
|
|
|||
|
|
@ -6404,7 +6404,7 @@ a")
|
|||
return ten1
|
||||
''')
|
||||
|
||||
lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]",
|
||||
lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]", "(1, 1)",
|
||||
"torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"]
|
||||
|
||||
dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half",
|
||||
|
|
@ -10307,6 +10307,11 @@ a")
|
|||
self.assertEqual(r.dtype, torch.float)
|
||||
self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)
|
||||
|
||||
def fn():
|
||||
return torch.zeros((1, 2, 3))
|
||||
|
||||
self.checkScript(fn, ())
|
||||
|
||||
def test_vararg_zeros(self):
|
||||
def foo():
|
||||
return torch.zeros(3, 4, 5, dtype=torch.int)
|
||||
|
|
|
|||
|
|
@ -226,8 +226,8 @@ bool Operator::matches(const Node* node) const {
|
|||
TypeEnv type_env;
|
||||
for (size_t i = 0; i < formals.size(); ++i) {
|
||||
auto formal = formals[i].type();
|
||||
const MatchTypeReturn matched_type =
|
||||
matchTypeVariables(formal, actuals[i]->type(), type_env);
|
||||
const MatchTypeReturn matched_type = matchTypeVariables(
|
||||
formal, actuals[i]->type(), type_env);
|
||||
if (!matched_type.success()) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -130,8 +130,8 @@ static Value* tryMatchArgument(
|
|||
}
|
||||
|
||||
// Resolve VarType variables
|
||||
const MatchTypeReturn matched =
|
||||
matchTypeVariables(arg.type(), value->type(), type_env);
|
||||
const MatchTypeReturn matched = matchTypeVariables(
|
||||
arg.type(), value->type(), type_env);
|
||||
if (!matched.success()) {
|
||||
if (failure_messages) {
|
||||
err() << "Could not match type " << value->type()->python_str() << " to "
|
||||
|
|
@ -242,11 +242,16 @@ static bool varargsCanBeUsedAsList(
|
|||
// The formal must be a list
|
||||
bool argument_is_list = arg.type()->kind() == TypeKind::ListType;
|
||||
|
||||
// matching varargs of typevar list nyi
|
||||
bool typevar_list = argument_is_list &&
|
||||
arg.type()->cast<ListType>()->getElementType()->cast<VarType>();
|
||||
|
||||
// it must not be a broadcasting list like int[3],
|
||||
// otherwise a single int is a valid input
|
||||
bool arg_is_broadcasting_list = bool(arg.N());
|
||||
|
||||
return is_last_argument && argument_is_list & !arg_is_broadcasting_list;
|
||||
return is_last_argument && argument_is_list & !arg_is_broadcasting_list &&
|
||||
!typevar_list;
|
||||
}
|
||||
|
||||
c10::optional<MatchedSchema> tryMatchSchema(
|
||||
|
|
|
|||
Loading…
Reference in a new issue