mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Improve IMethod::getArgumentNames to deal with empty argument names list (#62947)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62947 This diff improved IMethod::getArgumentNames to deal with empty argument names list. Test Plan: buck test mode/dev //caffe2/caffe2/fb/predictor:pytorch_predictor_test -- PyTorchDeployPredictor.GetEmptyArgumentNamesValidationMode buck test mode/dev //caffe2/caffe2/fb/predictor:pytorch_predictor_test -- PyTorchDeployPredictor.GetEmptyArgumentNamesRealMode Reviewed By: wconstab Differential Revision: D30179974 fbshipit-source-id: c7aec35c360a73318867c5b77ebfec3affee47e3
This commit is contained in:
parent
5cf32c1d09
commit
04caef8e1d
4 changed files with 15 additions and 4 deletions
|
|
@ -38,6 +38,7 @@ class IMethod {
|
|||
virtual void setArgumentNames(std::vector<std::string>& argumentNames) const = 0;
|
||||
|
||||
private:
|
||||
mutable bool isArgumentNamesInitialized_ { false };
|
||||
mutable std::vector<std::string> argumentNames_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -4,11 +4,11 @@ namespace torch {
|
|||
|
||||
const std::vector<std::string>& IMethod::getArgumentNames() const
|
||||
{
|
||||
// TODO(jwtan): Deal with empty parameter list.
|
||||
if (!argumentNames_.empty()) {
|
||||
if (isArgumentNamesInitialized_) {
|
||||
return argumentNames_;
|
||||
}
|
||||
|
||||
isArgumentNamesInitialized_ = true;
|
||||
setArgumentNames(argumentNames_);
|
||||
return argumentNames_;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -78,11 +78,18 @@ InterpreterManager::InterpreterManager(size_t n_interp) : resources_(n_interp) {
|
|||
}
|
||||
|
||||
// Pre-registered modules.
|
||||
// Since torch::deploy::Obj.toIValue cannot infer empty list, we hack it to
|
||||
// return None for empty list.
|
||||
// TODO(jwtan): Make the discovery of these modules easier.
|
||||
register_module_source(
|
||||
"GetArgumentNamesModule",
|
||||
"from inspect import signature\n"
|
||||
"def getArgumentNames(function): return list(signature(function).parameters.keys())\n");
|
||||
"from typing import Callable, Optional\n"
|
||||
"def getArgumentNames(function: Callable) -> Optional[list]:\n"
|
||||
" names = list(signature(function).parameters.keys())\n"
|
||||
" if len(names) == 0:\n"
|
||||
" return None\n"
|
||||
" return names\n");
|
||||
TORCH_DEPLOY_SAFE_CATCH_RETHROW
|
||||
}
|
||||
|
||||
|
|
@ -291,6 +298,10 @@ void PythonMethodWrapper::setArgumentNames(
|
|||
auto iArgumentNames =
|
||||
session.global("GetArgumentNamesModule", "getArgumentNames")({method})
|
||||
.toIValue();
|
||||
if (iArgumentNames.isNone()) {
|
||||
return;
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(iArgumentNames.isList());
|
||||
auto argumentNames = iArgumentNames.toListRef();
|
||||
|
||||
|
|
|
|||
|
|
@ -279,7 +279,6 @@ InferredType tryToInferContainerType(py::handle input);
|
|||
|
||||
// Try to infer the type of a Python object
|
||||
// The type cannot be inferred if:
|
||||
// input is a None
|
||||
// input is an empty container (list, dict)
|
||||
// input is an list with element types that cannot be unified
|
||||
// input is an dict with key or value types that cannot be unified
|
||||
|
|
|
|||
Loading…
Reference in a new issue