mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update base for Update on "[dynamo][not ready] polyfill infra for classes"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames [ghstack-poisoned]
This commit is contained in:
parent
7311b894c0
commit
e3da3d7107
2 changed files with 15 additions and 18 deletions
|
|
@ -293,9 +293,20 @@ class SideEffects:
|
|||
return variable
|
||||
|
||||
def get_variable_cls(self, user_cls):
|
||||
from torch.overrides import TorchFunctionMode
|
||||
|
||||
from .variables.ctx_manager import GenericContextWrappingVariable
|
||||
from .variables.torch_function import TorchFunctionModeVariable
|
||||
|
||||
variable_cls: type[
|
||||
variables.UserDefinedObjectVariable
|
||||
] = variables.UserDefinedObjectVariable
|
||||
if issubclass(
|
||||
user_cls, TorchFunctionMode
|
||||
) and TorchFunctionModeVariable.is_supported_torch_function_mode(user_cls):
|
||||
variable_cls = TorchFunctionModeVariable
|
||||
elif hasattr(user_cls, "__enter__") and hasattr(user_cls, "__exit__"):
|
||||
variable_cls = GenericContextWrappingVariable
|
||||
if issubclass(user_cls, torch.nn.Module):
|
||||
variable_cls = variables.UnspecializedNNModuleVariable
|
||||
elif issubclass(user_cls, (dict, collections.OrderedDict)):
|
||||
|
|
|
|||
|
|
@ -480,23 +480,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
and self.source
|
||||
and not is_forbidden_context_manager(self.value)
|
||||
):
|
||||
from torch.overrides import TorchFunctionMode
|
||||
|
||||
from .ctx_manager import GenericContextWrappingVariable
|
||||
from .functions import (
|
||||
BaseUserFunctionVariable,
|
||||
FunctionDecoratedByContextlibContextManagerVariable,
|
||||
)
|
||||
from .torch_function import TorchFunctionModeVariable
|
||||
|
||||
if issubclass(
|
||||
self.value, TorchFunctionMode
|
||||
) and TorchFunctionModeVariable.is_supported_torch_function_mode(
|
||||
self.value
|
||||
):
|
||||
var_cls = TorchFunctionModeVariable
|
||||
else:
|
||||
var_cls = GenericContextWrappingVariable
|
||||
|
||||
# graph break on any contextlib.* that it is not contextlib.contextmanager
|
||||
# Some of the APIs below are not supported because they rely on features
|
||||
|
|
@ -533,11 +520,10 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
)
|
||||
] + args[1:]
|
||||
|
||||
options = {}
|
||||
options["base_cls_vt"] = variables.BuiltinVariable(object)
|
||||
options["init_args"] = []
|
||||
cm_obj = tx.output.side_effects.track_object_new(
|
||||
self.source, self.value, var_cls, options
|
||||
cm_obj = tx.output.side_effects.track_new_user_defined_object(
|
||||
variables.BuiltinVariable(object),
|
||||
self,
|
||||
args,
|
||||
)
|
||||
cm_obj.call_method(tx, "__init__", args, kwargs)
|
||||
return cm_obj
|
||||
|
|
|
|||
Loading…
Reference in a new issue