mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Dynamo x autograd.Function supports setup_context (#124802)
Fixes part of #118397 Pull Request resolved: https://github.com/pytorch/pytorch/pull/124802 Approved by: https://github.com/zou3519
This commit is contained in:
parent
a866bfff45
commit
ce503c1b40
7 changed files with 79 additions and 20 deletions
|
|
@ -253,11 +253,11 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
|||
|
||||
def test_linear_setup_context(self):
|
||||
model = ModuleLinear()
|
||||
opt_model = torch._dynamo.optimize("eager")(model)
|
||||
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
|
||||
input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
|
||||
weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
|
||||
optim_result = opt_model(input, weight)
|
||||
eager_result = model(input, weight)
|
||||
optim_result = opt_model(input, weight)
|
||||
self.assertEqual(optim_result, eager_result)
|
||||
|
||||
def test_materialize_grad(self):
|
||||
|
|
|
|||
|
|
@ -3205,6 +3205,7 @@ MOD_INLINELIST = {
|
|||
"torch._dynamo.comptime",
|
||||
"torch._dynamo.polyfill",
|
||||
"torch._functorch.vmap",
|
||||
"torch._functorch.autograd_function",
|
||||
"torch._library.custom_ops",
|
||||
"torch._functorch.eager_transforms",
|
||||
"torch._inductor.test_operators",
|
||||
|
|
|
|||
|
|
@ -1628,13 +1628,12 @@ class AutogradFunctionApplyVariable(VariableTracker):
|
|||
fwd_src = AttrSource(self.parent_source, member="forward")
|
||||
ctx = AutogradFunctionContextVariable.create(tx, args, kwargs)
|
||||
if isinstance(self.fwd_graph, types.FunctionType):
|
||||
fwd_fn = UserFunctionVariable(self.fwd_graph, source=fwd_src)
|
||||
fwd_fn = UserFunctionVariable(self.fwd_graph)
|
||||
fwd_args = [ctx, *args]
|
||||
elif isinstance(self.fwd_graph, types.MethodType):
|
||||
fwd_fn = UserMethodVariable(
|
||||
self.fwd_graph.__func__,
|
||||
UserDefinedClassVariable(self.fwd_graph.__class__),
|
||||
source=fwd_src,
|
||||
)
|
||||
fwd_args = [fwd_fn.obj, ctx, *args]
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -357,14 +357,19 @@ class AutogradFunctionVariable(VariableTracker):
|
|||
and torch.is_grad_enabled()
|
||||
and config.capture_autograd_function
|
||||
):
|
||||
# Note - this is the same check used in autograd/function.py, except inverted.
|
||||
# If we want to support functorch transforms here, we will need to enable this.
|
||||
if (
|
||||
self.fn_cls.setup_context
|
||||
!= torch.autograd.function._SingleLevelFunction.setup_context
|
||||
):
|
||||
unimplemented(
|
||||
"NYI - autograd.Function with custom setup_context method"
|
||||
from torch._functorch.autograd_function import (
|
||||
autograd_function_forward_rewritten,
|
||||
)
|
||||
from torch.autograd.function import _is_setup_context_defined
|
||||
|
||||
forward_fn = self.fn_cls.forward
|
||||
|
||||
is_setup_ctx_defined = _is_setup_context_defined(self.fn_cls.setup_context)
|
||||
if is_setup_ctx_defined:
|
||||
# If setup_context is defined, we generate a new forward function which includes
|
||||
# the original forward and setup_context function, and trace the new forward function.
|
||||
forward_fn = autograd_function_forward_rewritten(
|
||||
self.fn_cls.forward, self.fn_cls.setup_context
|
||||
)
|
||||
|
||||
vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined]
|
||||
|
|
@ -383,12 +388,25 @@ class AutogradFunctionVariable(VariableTracker):
|
|||
tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
|
||||
)
|
||||
|
||||
return AutogradFunctionApplyVariable(
|
||||
self.fn_cls.forward,
|
||||
val = AutogradFunctionApplyVariable(
|
||||
forward_fn,
|
||||
self.fn_cls.backward,
|
||||
source,
|
||||
source=AttrSource(source, member="apply"),
|
||||
).call_function(tx, args, kwargs)
|
||||
# Inside of AutogradFunctionApplyVariable.call_function, we use sourceless variable wrapping
|
||||
# the forward function, as we don't want to generate guards for new_forward.__closure__
|
||||
# if forward is rewritten by autograd_function_forward_rewritten.
|
||||
# But we still need to generate correct guards for the original forward and setup_context
|
||||
# functions, so we have to add guards manually.
|
||||
if self.source:
|
||||
fwd_src = AttrSource(self.source, "forward")
|
||||
install_guard(fwd_src.make_guard(GuardBuilder.FUNCTION_MATCH))
|
||||
if is_setup_ctx_defined:
|
||||
setup_ctx_src = AttrSource(self.source, "setup_context")
|
||||
install_guard(setup_ctx_src.make_guard(GuardBuilder.FUNCTION_MATCH))
|
||||
|
||||
return val
|
||||
|
||||
if self.source:
|
||||
source = AttrSource(self.source, "forward")
|
||||
|
|
@ -443,7 +461,32 @@ class AutogradFunctionVariable(VariableTracker):
|
|||
return self.call_apply(tx, args, kwargs)
|
||||
|
||||
else:
|
||||
unimplemented(f"Unsupported method: {name}")
|
||||
from .. import trace_rules
|
||||
|
||||
source = AttrSource(self.source, name) if self.source is not None else None
|
||||
try:
|
||||
obj = inspect.getattr_static(self.fn_cls, name)
|
||||
except AttributeError:
|
||||
obj = None
|
||||
|
||||
if isinstance(obj, staticmethod):
|
||||
func = obj.__get__(self.fn_cls)
|
||||
if source is not None:
|
||||
return (
|
||||
trace_rules.lookup(func)
|
||||
.create_with_source(func, source=source)
|
||||
.call_function(tx, args, kwargs)
|
||||
)
|
||||
else:
|
||||
return trace_rules.lookup(func)(func).call_function(
|
||||
tx, args, kwargs
|
||||
)
|
||||
elif isinstance(obj, classmethod):
|
||||
return variables.UserMethodVariable(
|
||||
obj.__func__, self, source=source
|
||||
).call_function(tx, args, kwargs)
|
||||
else:
|
||||
unimplemented(f"Unsupported method: {name}")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import collections
|
||||
import contextlib
|
||||
import enum
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
|
|
@ -107,7 +108,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
|
||||
def var_getattr(self, tx, name: str) -> "VariableTracker":
|
||||
from .. import trace_rules
|
||||
from . import ConstantVariable
|
||||
from . import ConstantVariable, EnumVariable
|
||||
from .builder import VariableBuilder
|
||||
|
||||
if name == "__name__":
|
||||
|
|
@ -144,14 +145,16 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
if self.value is collections.OrderedDict and name == "fromkeys":
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
if name in getattr(self.value, "__dict__", {}) or (
|
||||
if ConstantVariable.is_literal(obj):
|
||||
return ConstantVariable.create(obj)
|
||||
elif isinstance(obj, enum.Enum):
|
||||
return EnumVariable(obj)
|
||||
elif name in getattr(self.value, "__dict__", {}) or (
|
||||
self.value.__module__.startswith("torch.")
|
||||
or self.value.__module__ == "torch"
|
||||
):
|
||||
if source:
|
||||
return VariableBuilder(tx, source)(obj)
|
||||
elif ConstantVariable.is_literal(obj):
|
||||
return ConstantVariable.create(obj)
|
||||
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
|
|
|
|||
|
|
@ -682,6 +682,15 @@ def reductify_leaf(
|
|||
return grad_input
|
||||
|
||||
|
||||
def autograd_function_forward_rewritten(original_forward, original_setup_context):
|
||||
def new_forward(ctx, *args, **kwargs):
|
||||
output = original_forward(*args, **kwargs)
|
||||
original_setup_context(ctx, args, output)
|
||||
return output
|
||||
|
||||
return new_forward
|
||||
|
||||
|
||||
class AutogradFunctionApply(HigherOrderOperator):
|
||||
def __init__(self):
|
||||
super().__init__("autograd_function_apply")
|
||||
|
|
|
|||
|
|
@ -561,7 +561,7 @@ class Function(_SingleLevelFunction):
|
|||
|
||||
return bound_args.args
|
||||
|
||||
is_setup_ctx_defined = cls.setup_context != _SingleLevelFunction.setup_context
|
||||
is_setup_ctx_defined = _is_setup_context_defined(cls.setup_context)
|
||||
if is_setup_ctx_defined:
|
||||
args = bind_default_args(cls.forward, *args, **kwargs)
|
||||
|
||||
|
|
@ -585,6 +585,10 @@ class Function(_SingleLevelFunction):
|
|||
return (ctx._autograd_function_id,)
|
||||
|
||||
|
||||
def _is_setup_context_defined(fn):
|
||||
return fn != _SingleLevelFunction.setup_context
|
||||
|
||||
|
||||
def once_differentiable(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(ctx, *args):
|
||||
|
|
|
|||
Loading…
Reference in a new issue