mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
[dynamo] fix crash in InspectSignatureVariable (#136010)
Fix crash that was happening in https://github.com/pytorch/pytorch/issues/128095, because we were trying to extract a constant incorrectly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136010 Approved by: https://github.com/yanboliang, https://github.com/anijain2305, https://github.com/jansel
This commit is contained in:
parent
f2b0fc89f2
commit
e037bb326f
1 changed files with 14 additions and 11 deletions
|
|
@ -41,6 +41,7 @@ from .functions import (
|
|||
UserMethodVariable,
|
||||
wrap_bound_arg,
|
||||
)
|
||||
from .nn_module import UnspecializedNNModuleVariable
|
||||
from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable
|
||||
|
||||
|
||||
|
|
@ -393,18 +394,20 @@ class InspectSignatureVariable(VariableTracker):
|
|||
super().__init__(**kwargs)
|
||||
self.inspected = inspected
|
||||
|
||||
try:
|
||||
if hasattr(self.inspected, "get_function"):
|
||||
self.fn = self.inspected.get_function()
|
||||
elif isinstance(self.inspected, UnspecializedNNModuleVariable):
|
||||
self.fn = self.inspected.value
|
||||
else:
|
||||
self.fn = self.inspected.as_python_constant()
|
||||
except NotImplementedError:
|
||||
unimplemented("inspect.signature with non-constant function")
|
||||
|
||||
self.signature = inspect.signature(self.fn)
|
||||
self.parameters = list(self.signature.parameters.items())
|
||||
if isinstance(self.inspected, UserMethodVariable):
|
||||
self.fn = self.inspected.get_function()
|
||||
self.signature = inspect.signature(self.fn)
|
||||
self.parameters = list(self.signature.parameters.items())[1:]
|
||||
elif isinstance(self.inspected, UserFunctionVariable):
|
||||
self.fn = self.inspected.get_function()
|
||||
self.signature = inspect.signature(self.fn)
|
||||
self.parameters = list(self.signature.parameters.items())
|
||||
else:
|
||||
self.fn = self.inspected.as_python_constant()
|
||||
self.signature = inspect.signature(self.fn)
|
||||
self.parameters = list(self.signature.parameters.items())
|
||||
self.parameters = self.parameters[1:]
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker":
|
||||
if name == "parameters":
|
||||
|
|
|
|||
Loading…
Reference in a new issue