From e037bb326fdafca243bdb08023bbef93b29a4513 Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 17 Sep 2024 00:48:24 +0000 Subject: [PATCH] [dynamo] fix crash in InspectSignatureVariable (#136010) Fix crash that was happening in https://github.com/pytorch/pytorch/issues/128095, because we were trying to extract a constant incorrectly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136010 Approved by: https://github.com/yanboliang, https://github.com/anijain2305, https://github.com/jansel --- torch/_dynamo/variables/misc.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 2b62c0c1acc..663ff5b20a6 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -41,6 +41,7 @@ from .functions import ( UserMethodVariable, wrap_bound_arg, ) +from .nn_module import UnspecializedNNModuleVariable from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable @@ -393,18 +394,20 @@ class InspectSignatureVariable(VariableTracker): super().__init__(**kwargs) self.inspected = inspected + try: + if hasattr(self.inspected, "get_function"): + self.fn = self.inspected.get_function() + elif isinstance(self.inspected, UnspecializedNNModuleVariable): + self.fn = self.inspected.value + else: + self.fn = self.inspected.as_python_constant() + except NotImplementedError: + unimplemented("inspect.signature with non-constant function") + + self.signature = inspect.signature(self.fn) + self.parameters = list(self.signature.parameters.items()) if isinstance(self.inspected, UserMethodVariable): - self.fn = self.inspected.get_function() - self.signature = inspect.signature(self.fn) - self.parameters = list(self.signature.parameters.items())[1:] - elif isinstance(self.inspected, UserFunctionVariable): - self.fn = self.inspected.get_function() - self.signature = inspect.signature(self.fn) - self.parameters = list(self.signature.parameters.items()) - else: - self.fn = self.inspected.as_python_constant() - self.signature = inspect.signature(self.fn) - self.parameters = list(self.signature.parameters.items()) + self.parameters = self.parameters[1:] def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": if name == "parameters":