[autograd] Support GradientEdge as output for torch.autograd.grad (#127766)

This is useful for splitting grad to run in two parts while preserving intermediates:

<details>

<summary>
Click to see code
</summary>

```python
import collections
import weakref
from torch.autograd.graph import GradientEdge

def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:
        return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn

def reverse_closure(roots, target_nodes):
    # Recurse until we reach a target node
    closure = set()
    actual_target_nodes = set()
    q: Deque = collections.deque()
    for node in roots:
        if node is not None and node not in closure:
            closure.add(node)
            q.append(node)
    while q:
        node = q.popleft()
        reverse_edges = node.metadata.get("reverse_edges", [])
        for holder_ref, idx in reverse_edges:
            ref = holder_ref()
            if ref is not None:
                raise RuntimeError("Reverse graph is no longer alive")
            fn = ref.node
            if fn in closure or fn is None:
                continue
            if fn in target_nodes:
                actual_target_nodes.add(fn)
                continue
            closure.add(fn)
            q.append(fn)
    return closure, actual_target_nodes

# Enable weak pointer
class Holder():
    def __init__(self, node):
        self.node = node

# TODO: use weak references to avoid reference cycle
def construct_reverse_graph(roots):
    q: Deque = collections.deque()
    root_seen = set()
    reverse_graph_refs = []
    for node in roots:
        if node is not None and node not in root_seen:
            q.append(node)
            root_seen.add(node)
    while q:
        node = q.popleft()
        for fn, idx in node.next_functions:
            if fn is not None:
                # Don't necessarily need to store on the graph
                reverse_edges = fn.metadata.get("reverse_edges", [])
                if len(reverse_edges) == 0:
                    q.append(fn)
                holder = Holder(node)
                holder_ref = weakref.ref(holder)
                reverse_graph_refs.append(holder)
                reverse_edges.append((holder_ref, idx))
                fn.metadata["reverse_edges"] = reverse_edges
    return reverse_graph_refs

def get_param_groups(inputs, params):
    inputs_closure, _ = reverse_closure(inputs, set())
    param_groups = dict()  # keyed on intermediates
    for i, param in enumerate(params):
        closure, intersected = reverse_closure([param], inputs_closure)
        param_group = {
            "params": set([param]),
            "intermediates": set(intersected),
        }
        for input_node in intersected:
            existing = param_groups.get(input_node, None)
            if existing is not None:
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(param_group["intermediates"])
                param_group = existing
            else:
                param_groups[input_node] = param_group

    # Sanity check: union of all param_groups params should be equal to all params
    union_params = set()
    seen_ids = set()
    unique_param_groups = []
    for param_group in param_groups.values():
        if id(param_group) not in seen_ids:
            seen_ids.add(id(param_group))
            unique_param_groups.append(param_group)
            union_params = union_params.union(param_group["params"])
    assert union_params == set(params)

    return unique_param_groups

def compute_grads_only_inputs2(roots, inps, weights):
    root_grad_fns = list(map(_get_grad_fn_or_grad_acc, roots))
    inp_grad_fns = list(map(_get_grad_fn_or_grad_acc, inps))
    weight_grad_fns = list(map(_get_grad_fn_or_grad_acc, weights))

    reverse_graph_refs = construct_reverse_graph(root_grad_fns)
    param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
    del reverse_graph_refs

    for param_group in param_groups:
        for i, intermediate in enumerate(param_group["intermediates"]):
            def get_hook(param_group, i):
                def hook(grad_inputs):
                    if param_group.get("grads", None) is None:
                        param_group["grads"] = [None] * len(param_group["intermediates"])
                    param_group["grads"][i] = grad_inputs
                return hook
            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(get_hook(param_group, i))

    dinputs = torch.autograd.grad((out,), inputs=tuple(inps), grad_outputs=(torch.ones_like(out),), retain_graph=True)
    return dinputs, param_groups

def compute_grads_only_weights2(user_weights, param_groups):
    all_dweights = dict()
    for param_group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(GradientEdge(i, 0) for i in param_group["intermediates"])
        weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])

        assert all(len(g) == 1 for g in param_group["grads"])
        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        print("trying to execute: ", intermediate_edges, weights_edges)
        dweights = torch.autograd.grad(intermediate_edges, weights_edges, grad_outputs=sum(param_group["grads"], tuple()))
        for w, dw in zip(param_group["params"], dweights):
            all_dweights[w] = dw
    # return grads in the original order weights were provided in
    out = []
    for w in user_weights:
        grad_acc = _get_grad_fn_or_grad_acc(w)
        out.append(all_dweights[grad_acc])
    return tuple(out)

```

</details>

```python
import torch.nn as nn

# Setup
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)

a = torch.rand(10, requires_grad=True)

weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)

out = mod2(mod1(a))

class LoggingTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        print(f"{func.__module__}.{func.__name__}")
        return rs

print(" -- SPLIT -- ")
# Compute gradients in two parts
with LoggingTensorMode():
    print("PART 1")
    dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
    print("PART 2")
    dweights = compute_grads_only_weights2(weights, state)

out = mod2(mod1(a))

print(" -- REF -- ")

# Compare with reference
with LoggingTensorMode():
    ref_all_gradients = torch.autograd.grad(out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),))

for actual, ref in zip(dinputs + dweights, ref_all_gradients):
    print(torch.allclose(actual, ref))

```

<img width="598" alt="image" src="https://github.com/pytorch/pytorch/assets/13428986/3681b8a7-3ab4-4d1d-a836-abef6913e671">

```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1ee160> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x100d7ae50> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a60> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a90> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute:  (GradientEdge(node=<AddmmBackward0 object at 0x12a1e4bb0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a21b130>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b7c0>, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a1cad60> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute:  (GradientEdge(node=<AddmmBackward0 object at 0x12a1ee0d0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a1e41c0>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b670>, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a21b2b0> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD
This commit is contained in:
soulitzer 2024-07-16 09:29:08 -04:00 committed by PyTorch MergeBot
parent c1e7e40f24
commit 2eec02523b
10 changed files with 359 additions and 75 deletions

View file

@ -2277,6 +2277,7 @@ known_failing_tests = {
"test_grad_materialize_grads", # RuntimeError: compiled_autograd does not support create_graph
"test_grad_nonleaf", # RuntimeError: compiled_autograd does not support create_graph
"test_grad_nonleaf_many_outputs", # RuntimeError: compiled_autograd does not support create_graph
"test_gradient_edge_output", # RuntimeError: trying to backward through the graph a second time
"test_hessian_vector", # RuntimeError: compiled_autograd does not support create_graph
"test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_False", # AttributeError: type object
"test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_True", # AttributeError: type object

View file

@ -436,6 +436,7 @@ class TestProfilerTree(TestCase):
[memory]""",
)
@unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
@unittest.skipIf(
TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
)
@ -478,8 +479,19 @@ class TestProfilerTree(TestCase):
<built-in function len>
torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple
torch/autograd/__init__.py(...): _make_grads
typing.py(...): inner
typing.py(...): __hash__
<built-in function hash>
typing.py(...): cast
<built-in function isinstance>
<built-in function isinstance>
<built-in function isinstance>
<built-in function isinstance>
<built-in function isinstance>
<built-in function isinstance>
<built-in method numel of Tensor object at 0xXXXXXXXXXXXX>
<built-in function isinstance>
<built-in function isinstance>
<built-in method ones_like of type object at 0xXXXXXXXXXXXX>
aten::ones_like
aten::empty_like
@ -910,6 +922,9 @@ class TestProfilerTree(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
@ProfilerTree.test
def test_profiler_experimental_tree_cuda_detailed(self):
# Do lazy imports ahead of time to avoid it showing up in the tree
import torch.nested._internal.nested_tensor
model = torch.nn.modules.Linear(1, 1, device="cuda")
model.train()
opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
@ -967,8 +982,19 @@ class TestProfilerTree(TestCase):
<built-in function len>
torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple
torch/autograd/__init__.py(...): _make_grads
typing.py(...): inner
typing.py(...): __hash__
<built-in function hash>
typing.py(...): cast
<built-in function isinstance>
<built-in function isinstance>
<built-in function isinstance>
<built-in function isinstance>
<built-in function isinstance>
<built-in function isinstance>
<built-in method numel of Tensor object at 0xXXXXXXXXXXXX>
<built-in function isinstance>
<built-in function isinstance>
<built-in method ones_like of type object at 0xXXXXXXXXXXXX>
aten::ones_like
aten::empty_like

View file

@ -985,6 +985,113 @@ class TestAutograd(TestCase):
torch.autograd.backward(out.sum(), inputs=(x, edge_y))
torch.autograd.backward(out.sum(), inputs=(edge_x, edge_y))
def test_grad_fn_input_metadata(self):
x = torch.rand(2, requires_grad=True, dtype=torch.float32)
y = torch.rand(2, requires_grad=True, dtype=torch.float32)
z = x * y
z_metadata = z.grad_fn._input_metadata[0]
self.assertEqual(z_metadata.shape, (2,))
self.assertEqual(z_metadata.dtype, torch.float32)
# Multiple outputs
b = torch.rand(3, 3, requires_grad=True)
var, _ = torch.var_mean(b, dim=0)
metadata_0 = var.grad_fn._input_metadata[0]
metadata_1 = var.grad_fn._input_metadata[1]
self.assertEqual(metadata_0.shape, (3,))
self.assertEqual(metadata_1.shape, (3,))
# Preserves symints
nt = torch.nested.nested_tensor(
[
torch.randn(
3,
2,
),
torch.randn(
2,
2,
),
],
layout=torch.jagged,
requires_grad=True,
)
nt_metadata = nt.clone().grad_fn._input_metadata[0]
self.assertIsInstance(nt_metadata.shape[1], torch.SymInt)
self.assertEqual(nt_metadata.shape, nt.shape)
self.assertTrue(nt_metadata.is_nested_tensor)
self.assertFalse(nt_metadata.is_cpp_nested_tensor)
self.assertEqual(nt_metadata.dtype, nt.dtype)
class Test(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output
x = torch.randn(3, 3, requires_grad=True)
x = Test.apply(x)
metadata = x.grad_fn._input_metadata[0]
self.assertEqual(metadata.shape, (3, 3))
def test_gradient_edge_output(self):
x = torch.tensor([1.0, 2.0], requires_grad=True)
def fn(x, reduce=True):
tmp = x.sin().cos()
if reduce:
tmp = tmp.sum()
out = tmp.exp().clone().sin().sum()
tmp_edge = torch.autograd.graph.get_gradient_edge(tmp)
return out, tmp_edge
# Compute fn backward in two steps
out, tmp_edge = fn(x)
(tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
(x_grad,) = torch.autograd.grad(tmp_edge, (x,), grad_outputs=(tmp_grad,))
# Compare with as if we did it in one go.
out, _ = fn(x)
(x_grad_ref,) = torch.autograd.grad(out, (x,))
self.assertEqual(x_grad, x_grad_ref)
# Incorrect case: grad_outputs not passed/implicitly None and output is
# not a scalar
out, tmp_edge = fn(x, reduce=False)
with self.assertRaisesRegex(
RuntimeError,
"grad can be implicitly created only for scalar output",
):
torch.autograd.grad(tmp_edge, (x,))
# grad_outputs is None, and output is a scalar is fine
out, tmp_edge = fn(x, reduce=True)
torch.autograd.grad(tmp_edge, (x,))
# Incorrect case: grad_outputs wrong size
out, tmp_edge = fn(x)
(tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
torch.autograd.grad(
tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0])
)
# Incorrect case: wrong dtype
out, tmp_edge = fn(x)
(tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
with self.assertRaisesRegex(RuntimeError, "required to have the same dtype"):
torch.autograd.grad(
tmp_edge,
(x,),
grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64),
)
def test_grad_nonleaf(self):
x_init = torch.randn(2, 2, requires_grad=True)
x = x_init

View file

@ -56,12 +56,21 @@ _ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor]
def _calculate_shape(
output: torch.Tensor, grad: torch.Tensor, is_grads_batched: bool
output: Union[torch.Tensor, graph.GradientEdge],
grad: torch.Tensor,
is_grads_batched: bool,
) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
# is_same_size ensures that both tensors are either nested or non nested
# circular import
from torch.nested._internal.nested_tensor import NestedTensor
if isinstance(output, graph.GradientEdge):
# We have already checked that we are not a C++ NestedTensor
if is_grads_batched:
raise RuntimeError("Batched grads are not supported with GradientEdge")
out_metadata = output.node._input_metadata[output.output_nr]
return torch.Size(out_metadata.shape), grad.shape
if output.is_nested and not isinstance(output, NestedTensor):
if is_grads_batched:
raise RuntimeError("Batched grads are not supported with Nested Tensor.")
@ -76,27 +85,58 @@ def _calculate_shape(
def _make_grads(
outputs: Sequence[torch.Tensor],
outputs: Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]],
grads: Sequence[_OptionalTensor],
is_grads_batched: bool,
) -> Tuple[_OptionalTensor, ...]:
new_grads: List[_OptionalTensor] = []
for out, grad in zip(outputs, grads):
out = cast(Union[torch.Tensor, graph.GradientEdge], out)
out_size = None
out_device = None
if isinstance(out, graph.GradientEdge):
out_metadata = out.node._input_metadata[out.output_nr]
out_size = torch.Size(out_metadata.shape)
out_dtype = out_metadata.dtype
out_device = out_metadata.device
out_is_nested = out_metadata.is_nested_tensor
if out_metadata.is_cpp_nested_tensor:
raise RuntimeError(
"C++ NestedTensor are not supported with GradientEdge"
)
out_is_cpp_nested = False
else:
# circular import
from torch.nested._internal.nested_tensor import NestedTensor
assert isinstance(out, torch.Tensor)
out_dtype = out.dtype
out_is_nested = out.is_nested
out_is_cpp_nested = out_is_nested and not isinstance(out, NestedTensor)
if not out_is_cpp_nested:
out_size = out.shape
if isinstance(grad, torch.Tensor):
from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq
first_grad = grad if not is_grads_batched else grad[0]
# TODO: We can remove this conditional once we uniformly use
# singleton int to represent jagged dimension, so that size() call
# on nested tensor works
if out.is_nested or first_grad.is_nested:
# on nested tensor works.
if out_is_cpp_nested:
assert isinstance(out, torch.Tensor)
shape_matches = torch.is_same_size(out, first_grad)
else:
# We need to do a regular size check, without going through
# the operator, to be able to handle unbacked symints
# (expect_true ensures we can deal with unbacked)
shape_matches = expect_true(sym_eq(out.size(), first_grad.size()))
assert out_size is not None
shape_matches = expect_true(sym_eq(out_size, first_grad.size()))
if not shape_matches:
out = cast(Union[torch.Tensor, graph.GradientEdge], out)
out_shape, grad_shape = _calculate_shape(
out, first_grad, is_grads_batched
)
@ -130,7 +170,7 @@ def _make_grads(
+ str(out_shape)
+ "."
)
if out.dtype.is_complex != grad.dtype.is_complex:
if out_dtype.is_complex != grad.dtype.is_complex:
raise RuntimeError(
"For complex Tensors, both grad_output and output"
" are required to have the same dtype."
@ -141,25 +181,43 @@ def _make_grads(
+ " and output["
+ str(outputs.index(out))
+ "] has a dtype of "
+ str(out.dtype)
+ str(out_dtype)
+ "."
)
new_grads.append(grad)
elif grad is None:
if out.requires_grad:
if out.numel() != 1:
if isinstance(out, graph.GradientEdge) or out.requires_grad: # type: ignore[attr-defined]
if isinstance(out, graph.GradientEdge):
assert out_size is not None
out_numel_is_1 = all(o == 1 for o in out_size)
else:
assert isinstance(out, torch.Tensor)
out_numel_is_1 = out.numel() == 1
if not out_numel_is_1:
raise RuntimeError(
"grad can be implicitly created only for scalar outputs"
)
if not out.dtype.is_floating_point:
if not out_dtype.is_floating_point:
msg = (
"grad can be implicitly created only for real scalar outputs"
f" but got {out.dtype}"
f" but got {out_dtype}"
)
raise RuntimeError(msg)
new_grads.append(
torch.ones_like(out, memory_format=torch.preserve_format)
)
if isinstance(out, graph.GradientEdge):
assert out_size is not None
assert out_device is not None
new_grads.append(
torch.ones(
out_size,
dtype=out_dtype,
device=out_device,
)
)
else:
assert isinstance(out, torch.Tensor)
new_grads.append(
torch.ones_like(out, memory_format=torch.preserve_format)
)
else:
new_grads.append(None)
else:
@ -297,7 +355,7 @@ def backward(
def grad(
outputs: _TensorOrTensors,
outputs: _TensorOrTensorsOrGradEdge,
inputs: _TensorOrTensorsOrGradEdge,
grad_outputs: Optional[_TensorOrTensors] = None,
retain_graph: Optional[bool] = None,
@ -327,7 +385,7 @@ def grad(
``torch.autograd.backward``.
Args:
outputs (sequence of Tensor): outputs of the differentiated function.
outputs (sequence of Tensor or GradientEdge): outputs of the differentiated function.
inputs (sequence of Tensor or GradientEdge): Inputs w.r.t. which the gradient will be
returned (and not accumulated into ``.grad``).
grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product.
@ -369,21 +427,24 @@ def grad(
)
if allow_unused is None:
allow_unused = materialize_grads
t_outputs = cast(
Tuple[torch.Tensor, ...],
(outputs,) if is_tensor_like(outputs) else tuple(outputs),
)
if is_tensor_like(outputs) or isinstance(outputs, graph.GradientEdge):
outputs = cast(
Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
)
else:
outputs = tuple(outputs)
if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
else:
inputs = tuple(inputs)
t_outputs = tuple(i for i in outputs if is_tensor_like(i))
t_inputs = tuple(i for i in inputs if is_tensor_like(i))
overridable_args = t_outputs + t_inputs
if has_torch_function(overridable_args):
return handle_torch_function(
grad,
overridable_args,
t_outputs,
outputs,
inputs,
grad_outputs=grad_outputs,
retain_graph=retain_graph,
@ -403,9 +464,9 @@ def grad(
stacklevel=2,
)
grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(t_outputs))
grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
grad_outputs_ = _make_grads(
t_outputs, grad_outputs_, is_grads_batched=is_grads_batched
outputs, grad_outputs_, is_grads_batched=is_grads_batched
)
if retain_graph is None:
@ -418,7 +479,7 @@ def grad(
def vjp(gO):
return _engine_run_backward(
t_outputs,
outputs,
gO,
retain_graph,
create_graph,
@ -432,7 +493,7 @@ def grad(
)
else:
result = _engine_run_backward(
t_outputs,
outputs,
grad_outputs_,
retain_graph,
create_graph,

View file

@ -79,6 +79,11 @@ class Node(abc.ABC):
r"""Return the metadata."""
raise NotImplementedError
@property
@abc.abstractmethod
def _input_metadata(self) -> List[Any]:
raise NotImplementedError
@abc.abstractmethod
def _register_hook_dict(self, tensor: torch.Tensor) -> None:
raise NotImplementedError
@ -170,7 +175,9 @@ class Node(abc.ABC):
return NotImplemented
def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node:
def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node:
if isinstance(t, GradientEdge):
return t.node
if t.requires_grad and t.grad_fn is None:
node = t.view_as(t).grad_fn.next_functions[0][0] # type: ignore[union-attr]
else:
@ -738,7 +745,7 @@ def allow_mutation_on_saved_tensors() -> (
def _register_logging_hooks_on_whole_graph(
t_outputs: Sequence[torch.Tensor],
t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
) -> Callable[[], None]:
grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))
@ -788,7 +795,7 @@ def _register_logging_hooks_on_whole_graph(
def _engine_run_backward(
t_outputs: Sequence[torch.Tensor],
t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
*args: Any,
**kwargs: Any,
) -> Tuple[torch.Tensor, ...]:

View file

@ -16,6 +16,7 @@
#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/input_metadata.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/profiler_python.h>
#include <torch/csrc/autograd/python_function.h>
@ -184,6 +185,23 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
.value("NO_GRAD_MODE", CreationMeta::NO_GRAD_MODE)
.value("INFERENCE_MODE", CreationMeta::INFERENCE_MODE);
py::class_<torch::autograd::InputMetadata>(m, "_InputMetadata")
.def_property_readonly(
"dtype",
[](const torch::autograd::InputMetadata& m) {
PyObject* raw_obj =
(PyObject*)torch::getTHPDtype(m.dtype().toScalarType());
return py::reinterpret_borrow<py::object>(raw_obj);
})
.def_property_readonly("device", &torch::autograd::InputMetadata::device)
.def_property_readonly(
"shape", &torch::autograd::InputMetadata::shape_as_dim_vector)
.def_property_readonly(
"is_nested_tensor", &torch::autograd::InputMetadata::is_nested_tensor)
.def_property_readonly(
"is_cpp_nested_tensor",
&torch::autograd::InputMetadata::is_cpp_nested_tensor);
py::class_<KinetoEvent>(m, "_KinetoEvent")
// name of the event
.def("name", [](const KinetoEvent& e) { return e.name(); })

View file

@ -210,6 +210,27 @@ PyObject* THPCppFunction_set_sequence_nr(
END_HANDLE_TH_ERRORS
}
PyObject* THPCppFunction_input_metadata(PyObject* self, void* closure) {
HANDLE_TH_ERRORS;
auto& fn = *((THPCppFunction*)self)->cdata;
const auto num_inputs =
fn.num_inputs(); // Assuming there's a method to get the number of inputs
THPObjectPtr list(PyTuple_New(num_inputs));
if (!list) {
return nullptr;
}
for (size_t i = 0; i < num_inputs; ++i) {
const auto& metadata = fn.input_metadata(i);
THPObjectPtr item(py::cast(metadata).release().ptr());
if (!item) {
return nullptr;
}
PyTuple_SET_ITEM(list.get(), i, item.release());
}
return list.release();
END_HANDLE_TH_ERRORS
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyMethodDef default_methods[] = {
THP_FUNCTION_DEFAULT_METHODS,

View file

@ -51,19 +51,21 @@ PyObject* CppFunction_pynew(
(char*)"_set_sequence_nr", THPCppFunction_set_sequence_nr, METH_O, nullptr \
}
#define THP_FUNCTION_DEFAULT_PROPERTIES \
{(char*)"next_functions", \
THPCppFunction_next_functions, \
nullptr, \
nullptr, \
nullptr}, \
{(char*)"requires_grad", \
THPCppFunction_requires_grad, \
nullptr, \
nullptr, \
nullptr}, \
{ \
(char*)"metadata", THPCppFunction_metadata, nullptr, nullptr, nullptr \
#define THP_FUNCTION_DEFAULT_PROPERTIES \
{(char*)"next_functions", \
THPCppFunction_next_functions, \
nullptr, \
nullptr, \
nullptr}, \
{(char*)"requires_grad", \
THPCppFunction_requires_grad, \
nullptr, \
nullptr, \
nullptr}, \
{(char*)"metadata", THPCppFunction_metadata, nullptr, nullptr, nullptr}, \
{ \
(char*)"_input_metadata", THPCppFunction_input_metadata, nullptr, nullptr, \
nullptr \
}
PyObject* THPCppFunction_next_functions(PyObject* self, void* _unused);
@ -75,6 +77,7 @@ PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook);
PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs);
PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs);
PyObject* THPCppFunction_input_metadata(PyObject* self, void* _unused);
PyTypeObject* _initFunctionPyTypeObject(
PyTypeObject& type,

View file

@ -166,6 +166,24 @@ c10::intrusive_ptr<at::ivalue::Future> PythonEngine::execute_with_graph_task(
PyObject* THPEngineClass = nullptr;
inline static Edge parseGradientEdge(PyObject* obj, int64_t index) {
PyObject* grad_fn = PyTuple_GetItem(obj, 0);
auto output_nr = THPUtils_unpackLong(PyTuple_GetItem(obj, 1));
std::shared_ptr<torch::autograd::Node> grad_fn_sp;
if (THPFunction_Check(grad_fn)) {
grad_fn_sp = ((THPFunction*)grad_fn)->cdata.lock();
} else if (THPCppFunction_Check(grad_fn)) {
grad_fn_sp = ((THPCppFunction*)grad_fn)->cdata;
} else {
TORCH_CHECK(
false,
"GradientEdge's first object must be an autograd.graph.Node "
"but got ",
THPUtils_typename(grad_fn));
}
return Edge(grad_fn_sp, output_nr);
}
// Implementation of torch._C._EngineBase.run_backward
PyObject* THPEngine_run_backward(
PyObject* self,
@ -239,22 +257,29 @@ PyObject* THPEngine_run_backward(
grads.reserve(num_tensors);
for (const auto i : c10::irange(num_tensors)) {
PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
TORCH_CHECK(
THPVariable_Check(_tensor),
"element ",
i,
" of tensors tuple is not a Tensor");
const auto& variable = THPVariable_Unpack(_tensor);
TORCH_CHECK(
!isBatchedTensor(variable),
"torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
"torch.vmap. We do not support the case where any outputs are ",
"vmapped tensors (output ",
i,
" is being vmapped over). Please "
"call autograd.grad() outside torch.vmap or file a bug report "
"with your use case.")
auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
Edge gradient_edge; // Temporary variable to hold the gradient edge
c10::optional<at::Tensor> mb_output;
if (THPVariable_Check(_tensor)) {
mb_output = THPVariable_Unpack(_tensor);
TORCH_CHECK(
!isBatchedTensor(mb_output.value()),
"torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
"torch.vmap. We do not support the case where any outputs are ",
"vmapped tensors (output ",
i,
" is being vmapped over). Please "
"call autograd.grad() outside torch.vmap or file a bug report "
"with your use case.");
gradient_edge = torch::autograd::impl::gradient_edge(mb_output.value());
} else if (PyObject_IsInstance(_tensor, THPGradientEdgeClass)) {
gradient_edge = parseGradientEdge(_tensor, i);
} else {
TORCH_CHECK(
false,
"element ",
i,
" of tensors tuple is neither a Tensor nor a GradientEdge");
}
TORCH_CHECK(
gradient_edge.function,
"element ",
@ -281,7 +306,13 @@ PyObject* THPEngine_run_backward(
i,
" of gradients tuple is not a Tensor or None");
TORCH_CHECK(
!variable.requires_grad(),
mb_output.has_value(),
"element ",
i,
" of gradients tuple is None, but the corresponding output is a GradientEdge."
"This is not supported.");
TORCH_CHECK(
!mb_output.value().requires_grad(),
"element ",
i,
" of gradients tuple is None, but the corresponding Tensor requires grad");
@ -330,23 +361,7 @@ PyObject* THPEngine_run_backward(
output_edges.emplace_back(grad_fn, output_nr);
}
} else if (PyObject_IsInstance(input, THPGradientEdgeClass)) {
auto node = PyTuple_GetItem(input, 0);
bool isTHPFunction = THPFunction_Check(node);
bool isTHPCppFunction = THPCppFunction_Check(node);
TORCH_CHECK(
isTHPFunction || isTHPCppFunction,
"GradientEdge first object must be an autograd.graph.Node "
"but got ",
THPUtils_typename(node));
std::shared_ptr<torch::autograd::Node> node_sp;
if (isTHPFunction) {
node_sp = ((THPFunction*)node)->cdata.lock();
} else {
node_sp = ((torch::autograd::THPCppFunction*)node)->cdata;
}
auto output_nr = THPUtils_unpackUInt32(PyTuple_GetItem(input, 1));
output_edges.emplace_back(node_sp, output_nr);
output_edges.emplace_back(parseGradientEdge(input, i));
} else {
TORCH_CHECK(
false,

View file

@ -1181,6 +1181,26 @@ PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) {
END_HANDLE_TH_ERRORS
}
PyObject* THPFunction_input_metadata(PyObject* self, void* unused) {
HANDLE_TH_ERRORS;
auto cdata = ((THPFunction*)self)->cdata.lock();
const auto num_inputs = cdata->num_inputs();
THPObjectPtr list(PyTuple_New(num_inputs));
if (!list) {
return nullptr;
}
for (size_t i = 0; i < num_inputs; ++i) {
const auto& metadata = cdata->input_metadata(i);
THPObjectPtr item(py::cast(metadata).release().ptr());
if (!item) {
return nullptr;
}
PyTuple_SET_ITEM(list.get(), i, item.release());
}
return list.release();
END_HANDLE_TH_ERRORS
}
PyObject* THPFunction_maybe_clear_saved_tensors(
PyObject* self,
PyObject* noargs) {
@ -1723,6 +1743,11 @@ static struct PyGetSetDef THPFunction_properties[] = {
nullptr},
{"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},
{"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr},
{"_input_metadata",
(getter)THPFunction_input_metadata,
nullptr,
nullptr,
nullptr},
{"materialize_grads",
nullptr,
(setter)THPFunction_set_materialize_grads,