diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 2b62c0c1acc..663ff5b20a6 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -41,6 +41,7 @@ from .functions import ( UserMethodVariable, wrap_bound_arg, ) +from .nn_module import UnspecializedNNModuleVariable from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable @@ -393,18 +394,20 @@ class InspectSignatureVariable(VariableTracker): super().__init__(**kwargs) self.inspected = inspected + try: + if hasattr(self.inspected, "get_function"): + self.fn = self.inspected.get_function() + elif isinstance(self.inspected, UnspecializedNNModuleVariable): + self.fn = self.inspected.value + else: + self.fn = self.inspected.as_python_constant() + except NotImplementedError: + unimplemented("inspect.signature with non-constant function") + + self.signature = inspect.signature(self.fn) + self.parameters = list(self.signature.parameters.items()) if isinstance(self.inspected, UserMethodVariable): - self.fn = self.inspected.get_function() - self.signature = inspect.signature(self.fn) - self.parameters = list(self.signature.parameters.items())[1:] - elif isinstance(self.inspected, UserFunctionVariable): - self.fn = self.inspected.get_function() - self.signature = inspect.signature(self.fn) - self.parameters = list(self.signature.parameters.items()) - else: - self.fn = self.inspected.as_python_constant() - self.signature = inspect.signature(self.fn) - self.parameters = list(self.signature.parameters.items()) + self.parameters = self.parameters[1:] def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": if name == "parameters":