Improve argument printing (#87601)

No more "expected tuple but got tuple".  We appropriately
grovel in the list/tuple for the element that mismatched
and report what exactly twinged the failure.

invalid_arguments.cpp is a shitshow so I did something
slapdash to get it not completely horrible.  See
https://github.com/pytorch/pytorch/issues/87514 for more context.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87601
Approved by: https://github.com/Chillee
This commit is contained in:
albanD 2022-10-24 15:37:20 -04:00 committed by PyTorch MergeBot
parent 72ec1b5fc1
commit 3263bd24be
4 changed files with 130 additions and 21 deletions

View file

@ -19,6 +19,46 @@ class IntListWrapperModule(torch.nn.Module):
class TestNativeFunctions(TestCase):
def _lists_with_str(self):
return [
("foo",),
(2, "foo"),
("foo", 3),
["foo"],
[2, "foo"],
["foo", 3],
"foo",
]
def _test_raises_str_typeerror(self, fn):
for arg in self._lists_with_str():
self.assertRaisesRegex(TypeError, "str", lambda: fn(arg))
try:
fn(arg)
except TypeError as e:
print(e)
def test_symintlist_error(self):
x = torch.randn(1)
self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))
def test_vararg_symintlist_error(self):
self._test_raises_str_typeerror(lambda arg: torch.rand(arg))
self._test_raises_str_typeerror(lambda arg: torch.rand(*arg))
def test_symintlist_error_with_overload_but_is_unique(self):
x = torch.randn(1)
y = torch.randn(1)
self._test_raises_str_typeerror(lambda arg: x.set_(y, 0, arg))
def test_symintlist_error_with_overload(self):
x = torch.randn(1)
self._test_raises_str_typeerror(lambda arg: x.view(arg))
def test_intlist_error_with_overload(self):
x = torch.randn(1)
self._test_raises_str_typeerror(lambda arg: torch._C._nn.pad(x, arg))
#
# optional float list
#
@ -113,7 +153,7 @@ class TestNativeFunctions(TestCase):
self.do_test_optional_intlist_with_module(fake_module)
def test_optional_intlist_invalid(self):
with self.assertRaisesRegex(TypeError, "must be .* not"):
with self.assertRaisesRegex(TypeError, "must be .* but found"):
IntListWrapperModule()(torch.zeros(1), [0.5])
with self.assertRaisesRegex(RuntimeError, "value of type .* instead found type"):

View file

@ -272,7 +272,34 @@ std::string _formattedArgDesc(
result += red;
if (is_kwarg)
result += option.arguments[i].name + "=";
result += py_typename(arg);
bool is_tuple = PyTuple_Check(arg);
if (is_tuple || PyList_Check(arg)) {
result += py_typename(arg) + " of ";
auto num_elements = PySequence_Length(arg);
if (is_tuple) {
result += "(";
} else {
result += "[";
}
for (const auto i : c10::irange(num_elements)) {
if (i != 0) {
result += ", ";
}
result += py_typename(
py::reinterpret_steal<py::object>(PySequence_GetItem(arg, i))
.ptr());
}
if (is_tuple) {
if (num_elements == 1) {
result += ",";
}
result += ")";
} else {
result += "]";
}
} else {
result += py_typename(arg);
}
if (is_matching)
result += reset_green;
else

View file

@ -664,7 +664,10 @@ bool is_float_or_complex_list(PyObject* obj) {
return true;
}
static bool is_int_list(PyObject* obj, int broadcast_size) {
static bool is_int_list(
PyObject* obj,
int broadcast_size,
int64_t* failed_idx = nullptr) {
if (PyTuple_Check(obj) || PyList_Check(obj)) {
auto len = PySequence_Size(obj);
if (len == 0) {
@ -684,6 +687,9 @@ static bool is_int_list(PyObject* obj, int broadcast_size) {
for (int i = 1; i < len; i++) {
if (torch::is_symint_node(
py::reinterpret_steal<py::object>(PySequence_GetItem(obj, i)))) {
if (failed_idx != nullptr) {
*failed_idx = i;
}
return false;
}
}
@ -694,9 +700,13 @@ static bool is_int_list(PyObject* obj, int broadcast_size) {
// NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
// in an intlist argument. Even float or complex scalar tensors.
return (
jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{});
bool r =
(jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{});
if (!r && failed_idx != nullptr) {
*failed_idx = 0;
}
return r;
}
// if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
// int
@ -711,7 +721,10 @@ static bool is_int_or_symint(PyObject* obj) {
return torch::is_symint_node(py::handle(obj)) || THPUtils_checkIndex(obj);
}
static bool is_int_or_symint_list(PyObject* obj, int broadcast_size) {
static bool is_int_or_symint_list(
PyObject* obj,
int broadcast_size,
int64_t* failed_idx = nullptr) {
if (PyTuple_Check(obj) || PyList_Check(obj)) {
if (PySequence_Size(obj) == 0) {
return true;
@ -723,9 +736,13 @@ static bool is_int_or_symint_list(PyObject* obj, int broadcast_size) {
}
// NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
// in an intlist argument. Even float or complex scalar tensors.
return (
jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{});
bool r =
(jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
THPVariable_Unpack(item.ptr()).sizes() == c10::IntArrayRef{});
if (!r && failed_idx != nullptr) {
*failed_idx = 0;
}
return r;
}
// if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
// int
@ -736,7 +753,8 @@ static bool is_int_or_symint_list(PyObject* obj, int broadcast_size) {
auto FunctionParameter::check(
PyObject* obj,
std::vector<py::handle>& overloaded_args,
int argnum) -> bool {
int argnum,
int64_t* failed_idx) -> bool {
switch (type_) {
case ParameterType::TENSOR: {
if (is_tensor_and_append_overloaded(obj, &overloaded_args)) {
@ -793,7 +811,7 @@ auto FunctionParameter::check(
obj, &overloaded_args, argnum, true /* throw_error */);
}
case ParameterType::INT_LIST:
return is_int_list(obj, size);
return is_int_list(obj, size, failed_idx);
case ParameterType::FLOAT_LIST:
return is_float_or_complex_list(obj);
case ParameterType::GENERATOR:
@ -824,12 +842,13 @@ auto FunctionParameter::check(
case ParameterType::SYM_INT:
return is_int_or_symint(obj);
case ParameterType::SYM_INT_LIST:
return is_int_or_symint_list(obj, size);
return is_int_or_symint_list(obj, size, failed_idx);
default:
throw std::runtime_error("unknown parameter type");
}
}
// WARNING: these strings are parsed invalid_arguments.cpp
std::string FunctionParameter::type_name() const {
switch (type_) {
case ParameterType::TENSOR:
@ -837,9 +856,10 @@ std::string FunctionParameter::type_name() const {
case ParameterType::SCALAR:
return "Number";
case ParameterType::INT64:
return "int";
// NB: SymInt is intentionally not mentioned here, as conventional user
// use will only know about ints
case ParameterType::SYM_INT:
return "SymInt";
return "int";
case ParameterType::DOUBLE:
return "float";
case ParameterType::COMPLEX:
@ -877,7 +897,7 @@ std::string FunctionParameter::type_name() const {
case ParameterType::SCALAR_LIST:
return "tuple of Scalars";
case ParameterType::SYM_INT_LIST:
return "tuple of SymInts";
return "tuple of ints";
default:
throw std::runtime_error("unknown parameter type");
}
@ -1341,6 +1361,8 @@ bool FunctionSignature::parse(
is_kwd = true;
}
int64_t failed_idx = -1;
bool varargs_eligible = allow_varargs_intlist && arg_pos == 0 && !is_kwd;
if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) {
dst[i++] = nullptr;
} else if (!obj) {
@ -1349,15 +1371,16 @@ bool FunctionSignature::parse(
missing_args(*this, i);
}
return false;
} else if (param.check(obj, this->overloaded_args, i)) {
} else if (param.check(obj, this->overloaded_args, i, &failed_idx)) {
dst[i++] = obj;
// XXX: the Variable check is necessary because sizes become tensors when
// tracer is enabled. This behavior easily leads to ambiguities, and we
// should avoid having complex signatures that make use of it...
} else if (
allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
((int_list_overload ? is_int_list(args, param.size)
: is_int_or_symint_list(args, param.size)))) {
varargs_eligible &&
((int_list_overload
? is_int_list(args, param.size, &failed_idx)
: is_int_or_symint_list(args, param.size, &failed_idx)))) {
// take all positional arguments as this parameter
// e.g. permute(1, 2, 3) -> permute((1, 2, 3))
dst[i++] = args;
@ -1374,6 +1397,24 @@ bool FunctionSignature::parse(
Py_TYPE(obj)->tp_name);
} else {
// foo(): argument 'other' (position 2) must be str, not int
if (failed_idx != -1) {
if (!(PyTuple_Check(obj) || PyList_Check(obj))) {
TORCH_INTERNAL_ASSERT(varargs_eligible);
obj = args;
}
TORCH_INTERNAL_ASSERT(failed_idx < PySequence_Size(obj));
throw TypeError(
"%s(): argument '%s' (position %ld) must be %s, but found element of type %s at pos %ld",
name.c_str(),
param.name.c_str(),
static_cast<long>(arg_pos + 1),
param.type_name().c_str(),
Py_TYPE(py::reinterpret_steal<py::object>(
PySequence_GetItem(obj, failed_idx))
.ptr())
->tp_name,
static_cast<long>(failed_idx));
}
throw TypeError(
"%s(): argument '%s' (position %ld) must be %s, not %s",
name.c_str(),

View file

@ -382,7 +382,8 @@ struct FunctionParameter {
bool check(
PyObject* obj,
std::vector<py::handle>& overloaded_args,
int argnum);
int argnum,
int64_t* failed_idx = nullptr);
void set_default_str(const std::string& str);
std::string type_name() const;