Improve IMethod::getArgumentNames to deal with empty argument names list (#62947)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62947

This diff improved IMethod::getArgumentNames to deal with empty argument names list.

Test Plan:
buck test mode/dev //caffe2/caffe2/fb/predictor:pytorch_predictor_test -- PyTorchDeployPredictor.GetEmptyArgumentNamesValidationMode
buck test mode/dev //caffe2/caffe2/fb/predictor:pytorch_predictor_test -- PyTorchDeployPredictor.GetEmptyArgumentNamesRealMode

Reviewed By: wconstab

Differential Revision: D30179974

fbshipit-source-id: c7aec35c360a73318867c5b77ebfec3affee47e3
This commit is contained in:
Jiewen Tan 2021-08-11 16:42:34 -07:00 committed by Facebook GitHub Bot
parent 5cf32c1d09
commit 04caef8e1d
4 changed files with 15 additions and 4 deletions

View file

@ -38,6 +38,7 @@ class IMethod {
virtual void setArgumentNames(std::vector<std::string>& argumentNames) const = 0;
private:
mutable bool isArgumentNamesInitialized_ { false };
mutable std::vector<std::string> argumentNames_;
};

View file

@ -4,11 +4,11 @@ namespace torch {
const std::vector<std::string>& IMethod::getArgumentNames() const
{
// TODO(jwtan): Deal with empty parameter list.
if (!argumentNames_.empty()) {
if (isArgumentNamesInitialized_) {
return argumentNames_;
}
isArgumentNamesInitialized_ = true;
setArgumentNames(argumentNames_);
return argumentNames_;
}

View file

@ -78,11 +78,18 @@ InterpreterManager::InterpreterManager(size_t n_interp) : resources_(n_interp) {
}
// Pre-registered modules.
// Since torch::deploy::Obj.toIValue cannot infer empty list, we hack it to
// return None for empty list.
// TODO(jwtan): Make the discovery of these modules easier.
register_module_source(
"GetArgumentNamesModule",
"from inspect import signature\n"
"def getArgumentNames(function): return list(signature(function).parameters.keys())\n");
"from typing import Callable, Optional\n"
"def getArgumentNames(function: Callable) -> Optional[list]:\n"
" names = list(signature(function).parameters.keys())\n"
" if len(names) == 0:\n"
" return None\n"
" return names\n");
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
@ -291,6 +298,10 @@ void PythonMethodWrapper::setArgumentNames(
auto iArgumentNames =
session.global("GetArgumentNamesModule", "getArgumentNames")({method})
.toIValue();
if (iArgumentNames.isNone()) {
return;
}
TORCH_INTERNAL_ASSERT(iArgumentNames.isList());
auto argumentNames = iArgumentNames.toListRef();

View file

@ -279,7 +279,6 @@ InferredType tryToInferContainerType(py::handle input);
// Try to infer the type of a Python object
// The type cannot be inferred if:
// input is a None
// input is an empty container (list, dict)
// input is an list with element types that cannot be unified
// input is an dict with key or value types that cannot be unified