From 2eec02523b1c2fa6c818530178deb5f30a8ba26d Mon Sep 17 00:00:00 2001
From: soulitzer
Date: Tue, 16 Jul 2024 09:29:08 -0400
Subject: [PATCH] [autograd] Support GradientEdge as output for
 torch.autograd.grad (#127766)

This is useful for splitting grad to run in two parts while preserving
intermediates:
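A minimal sketch of the new capability, adapted from the `test_gradient_edge_output` test added by this patch (the backward is split at `tmp`; the second call passes the edge as an *output*, which is what this PR enables):

```python
import torch
from torch.autograd.graph import get_gradient_edge

x = torch.tensor([1.0, 2.0], requires_grad=True)
tmp = x.sin().cos().sum()  # intermediate at which we split the backward
out = tmp.exp()

tmp_edge = get_gradient_edge(tmp)

# Part 1: out -> tmp. Passing a GradientEdge as an *input* was already supported.
(tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))

# Part 2: tmp -> x. Passing a GradientEdge as an *output* is new with this patch.
# No retain_graph is needed: the two calls traverse disjoint regions of the graph.
(x_grad,) = torch.autograd.grad(tmp_edge, (x,), grad_outputs=(tmp_grad,))
```

The full example below builds on the same mechanism to partition a backward into an inputs-only pass and a weights-only pass: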
```python
import collections
import weakref
from typing import Deque

import torch
from torch.autograd.graph import GradientEdge


def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:
        # Leaf tensor: materialize its AccumulateGrad node via a view
        return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn


def reverse_closure(roots, target_nodes):
    # Recurse until we reach a target node
    closure = set()
    actual_target_nodes = set()
    q: Deque = collections.deque()
    for node in roots:
        if node is not None and node not in closure:
            closure.add(node)
            q.append(node)
    while q:
        node = q.popleft()
        reverse_edges = node.metadata.get("reverse_edges", [])
        for holder_ref, idx in reverse_edges:
            ref = holder_ref()
            if ref is None:
                raise RuntimeError("Reverse graph is no longer alive")
            fn = ref.node
            if fn in closure or fn is None:
                continue
            if fn in target_nodes:
                actual_target_nodes.add(fn)
                continue
            closure.add(fn)
            q.append(fn)
    return closure, actual_target_nodes


# Enable weak pointer
class Holder:
    def __init__(self, node):
        self.node = node


# TODO: use weak references to avoid reference cycle
def construct_reverse_graph(roots):
    q: Deque = collections.deque()
    root_seen = set()
    reverse_graph_refs = []
    for node in roots:
        if node is not None and node not in root_seen:
            q.append(node)
            root_seen.add(node)
    while q:
        node = q.popleft()
        for fn, idx in node.next_functions:
            if fn is not None:
                # Don't necessarily need to store on the graph
                reverse_edges = fn.metadata.get("reverse_edges", [])
                if len(reverse_edges) == 0:
                    q.append(fn)
                holder = Holder(node)
                holder_ref = weakref.ref(holder)
                reverse_graph_refs.append(holder)
                reverse_edges.append((holder_ref, idx))
                fn.metadata["reverse_edges"] = reverse_edges
    return reverse_graph_refs


def get_param_groups(inputs, params):
    # Group params by the "intermediate" nodes (in the reverse closure of the
    # inputs) that their backward paths pass through; groups that share an
    # intermediate are merged.
    inputs_closure, _ = reverse_closure(inputs, set())
    param_groups = dict()  # keyed on intermediates
    for i, param in enumerate(params):
        closure, intersected = reverse_closure([param], inputs_closure)
        param_group = {
            "params": set([param]),
            "intermediates": set(intersected),
        }
        for input_node in intersected:
            existing = param_groups.get(input_node, None)
            if existing is not None:
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(
                    param_group["intermediates"]
                )
                param_group = existing
            else:
                param_groups[input_node] = param_group

    # Sanity check: union of all param_groups params should be equal to all params
    union_params = set()
    seen_ids = set()
    unique_param_groups = []
    for param_group in param_groups.values():
        if id(param_group) not in seen_ids:
            seen_ids.add(id(param_group))
            unique_param_groups.append(param_group)
            union_params = union_params.union(param_group["params"])
    assert union_params == set(params)
    return unique_param_groups


def compute_grads_only_inputs2(roots, inps, weights):
    root_grad_fns = list(map(_get_grad_fn_or_grad_acc, roots))
    inp_grad_fns = list(map(_get_grad_fn_or_grad_acc, inps))
    weight_grad_fns = list(map(_get_grad_fn_or_grad_acc, weights))

    reverse_graph_refs = construct_reverse_graph(root_grad_fns)
    param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
    del reverse_graph_refs

    for param_group in param_groups:
        for i, intermediate in enumerate(param_group["intermediates"]):

            def get_hook(param_group, i):
                def hook(grad_inputs):
                    if param_group.get("grads", None) is None:
                        param_group["grads"] = [None] * len(
                            param_group["intermediates"]
                        )
                    param_group["grads"][i] = grad_inputs

                return hook

            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(get_hook(param_group, i))

    dinputs = torch.autograd.grad(
        roots,
        inputs=tuple(inps),
        grad_outputs=tuple(torch.ones_like(r) for r in roots),
        retain_graph=True,
    )
    return dinputs, param_groups


def compute_grads_only_weights2(user_weights, param_groups):
    all_dweights = dict()
    for param_group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(
            GradientEdge(i, 0) for i in param_group["intermediates"]
        )
        weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])

        assert all(len(g) == 1 for g in param_group["grads"])
        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        print("trying to execute: ", intermediate_edges, weights_edges)
        dweights = torch.autograd.grad(
            intermediate_edges,
            weights_edges,
            grad_outputs=sum(param_group["grads"], tuple()),
        )
        for w, dw in zip(param_group["params"], dweights):
            all_dweights[w] = dw
    # return grads in the original order weights were provided in
    out = []
    for w in user_weights:
        grad_acc = _get_grad_fn_or_grad_acc(w)
        out.append(all_dweights[grad_acc])
    return tuple(out)
```
```python
import torch
import torch.nn as nn

# Setup
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)

a = torch.rand(10, requires_grad=True)

weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)

out = mod2(mod1(a))


class LoggingTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        print(f"{func.__module__}.{func.__name__}")
        return rs


print(" -- SPLIT -- ")
# Compute gradients in two parts
with LoggingTensorMode():
    print("PART 1")
    dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
    print("PART 2")
    dweights = compute_grads_only_weights2(weights, state)

out = mod2(mod1(a))
print(" -- REF -- ")
# Compare with reference
with LoggingTensorMode():
    ref_all_gradients = torch.autograd.grad(
        out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),)
    )

for actual, ref in zip(dinputs + dweights, ref_all_gradients):
    print(torch.allclose(actual, ref))
```

```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute:  (GradientEdge(node=, output_nr=0),) (GradientEdge(node=, output_nr=0), GradientEdge(node=, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute:  (GradientEdge(node=, output_nr=0),) (GradientEdge(node=, output_nr=0), GradientEdge(node=, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD
---
 test/inductor/test_compiled_autograd.py     |   1 +
 test/profiler/test_profiler_tree.py         |  26 +++++
 test/test_autograd.py                       | 107 +++++++++++++++++++
 torch/autograd/__init__.py                  | 111 +++++++++++++++-----
 torch/autograd/graph.py                     |  13 ++-
 torch/csrc/autograd/init.cpp                |  18 ++++
 torch/csrc/autograd/python_cpp_function.cpp |  21 ++++
 torch/csrc/autograd/python_cpp_function.h   |  29 ++---
 torch/csrc/autograd/python_engine.cpp       |  83 +++++++++------
 torch/csrc/autograd/python_function.cpp     |  25 +++++
 10 files changed, 359 insertions(+), 75 deletions(-)

diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index 0350f68bffb..ef0651f9921 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -2277,6 +2277,7 @@ known_failing_tests = {
     "test_grad_materialize_grads",  # RuntimeError: compiled_autograd does not support create_graph
     "test_grad_nonleaf",  # RuntimeError: compiled_autograd does not support create_graph
     "test_grad_nonleaf_many_outputs",  # RuntimeError: compiled_autograd does not support create_graph
+    "test_gradient_edge_output",  # RuntimeError: trying to backward through the graph a second time
     "test_hessian_vector",  # RuntimeError: compiled_autograd does not support create_graph
     "test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_False",  # AttributeError: type object
     "test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_True",  # AttributeError: type object
diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py
index 6d64f1175fa..580c9ea12d9 100644
--- a/test/profiler/test_profiler_tree.py
+++ b/test/profiler/test_profiler_tree.py
@@ -436,6 +436,7 @@ class TestProfilerTree(TestCase):
             [memory]""",
         )
 
+    @unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
    @unittest.skipIf(
         TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
     )
@@ -478,8 +479,19 @@ class TestProfilerTree(TestCase):
             torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple
             torch/autograd/__init__.py(...): _make_grads
+              typing.py(...): inner
+                typing.py(...): __hash__
+
+              typing.py(...): cast
+
+
+
+
+
+
+
             aten::ones_like
               aten::empty_like
@@ -910,6 +922,9 @@ class TestProfilerTree(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
     @ProfilerTree.test
     def test_profiler_experimental_tree_cuda_detailed(self):
+        # Do lazy imports ahead of time to avoid it showing up in the tree
+        import torch.nested._internal.nested_tensor
+
         model = torch.nn.modules.Linear(1, 1, device="cuda")
         model.train()
         opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
@@ -967,8 +982,19 @@ class TestProfilerTree(TestCase):
             torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple
             torch/autograd/__init__.py(...): _make_grads
+              typing.py(...): inner
+                typing.py(...): __hash__
+
+              typing.py(...): cast
+
+
+
+
+
+
+
             aten::ones_like
               aten::empty_like
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 4f7c095841f..0b00a367fcd 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -985,6 +985,113 @@ class TestAutograd(TestCase):
         torch.autograd.backward(out.sum(), inputs=(x, edge_y))
         torch.autograd.backward(out.sum(), inputs=(edge_x, edge_y))
 
+    def test_grad_fn_input_metadata(self):
+        x = torch.rand(2, requires_grad=True, dtype=torch.float32)
+        y = torch.rand(2, requires_grad=True, dtype=torch.float32)
+        z = x * y
+        z_metadata = z.grad_fn._input_metadata[0]
+        self.assertEqual(z_metadata.shape, (2,))
+        self.assertEqual(z_metadata.dtype, torch.float32)
+
+        # Multiple outputs
+        b = torch.rand(3, 3, requires_grad=True)
+        var, _ = torch.var_mean(b, dim=0)
+
+        metadata_0 = var.grad_fn._input_metadata[0]
+        metadata_1 = var.grad_fn._input_metadata[1]
+
+        self.assertEqual(metadata_0.shape, (3,))
+        self.assertEqual(metadata_1.shape, (3,))
+
+        # Preserves symints
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(
+                    3,
+                    2,
+                ),
+                torch.randn(
+                    2,
+                    2,
+                ),
+            ],
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+        nt_metadata = nt.clone().grad_fn._input_metadata[0]
+
+        self.assertIsInstance(nt_metadata.shape[1], torch.SymInt)
+        self.assertEqual(nt_metadata.shape, nt.shape)
+        self.assertTrue(nt_metadata.is_nested_tensor)
+        self.assertFalse(nt_metadata.is_cpp_nested_tensor)
+        self.assertEqual(nt_metadata.dtype, nt.dtype)
+
+        class Test(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                return grad_output
+
+        x = torch.randn(3, 3, requires_grad=True)
+        x = Test.apply(x)
+        metadata = x.grad_fn._input_metadata[0]
+        self.assertEqual(metadata.shape, (3, 3))
+
+    def test_gradient_edge_output(self):
+        x = torch.tensor([1.0, 2.0], requires_grad=True)
+
+        def fn(x, reduce=True):
+            tmp = x.sin().cos()
+            if reduce:
+                tmp = tmp.sum()
+            out = tmp.exp().clone().sin().sum()
+            tmp_edge = torch.autograd.graph.get_gradient_edge(tmp)
+            return out, tmp_edge
+
+        # Compute fn backward in two steps
+        out, tmp_edge = fn(x)
+        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
+
+        (x_grad,) = torch.autograd.grad(tmp_edge, (x,), grad_outputs=(tmp_grad,))
+
+        # Compare with as if we did it in one go.
+        out, _ = fn(x)
+        (x_grad_ref,) = torch.autograd.grad(out, (x,))
+        self.assertEqual(x_grad, x_grad_ref)
+
+        # Incorrect case: grad_outputs not passed/implicitly None and output is
+        # not a scalar
+        out, tmp_edge = fn(x, reduce=False)
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "grad can be implicitly created only for scalar output",
+        ):
+            torch.autograd.grad(tmp_edge, (x,))
+
+        # grad_outputs is None, and output is a scalar is fine
+        out, tmp_edge = fn(x, reduce=True)
+        torch.autograd.grad(tmp_edge, (x,))
+
+        # Incorrect case: grad_outputs wrong size
+        out, tmp_edge = fn(x)
+        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
+        with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
+            torch.autograd.grad(
+                tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0])
+            )
+
+        # Incorrect case: wrong dtype
+        out, tmp_edge = fn(x)
+        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
+        with self.assertRaisesRegex(RuntimeError, "required to have the same dtype"):
+            torch.autograd.grad(
+                tmp_edge,
+                (x,),
+                grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64),
+            )
+
     def test_grad_nonleaf(self):
         x_init = torch.randn(2, 2, requires_grad=True)
         x = x_init
diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py
index 5c68b94b001..a11491ac832 100644
--- a/torch/autograd/__init__.py
+++ b/torch/autograd/__init__.py
@@ -56,12 +56,21 @@ _ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor]
 
 
 def _calculate_shape(
-    output: torch.Tensor, grad: torch.Tensor, is_grads_batched: bool
+    output: Union[torch.Tensor, graph.GradientEdge],
+    grad: torch.Tensor,
+    is_grads_batched: bool,
 ) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
     # is_same_size ensures that both tensors are either nested or non nested
     # circular import
     from torch.nested._internal.nested_tensor import NestedTensor
 
+    if isinstance(output, graph.GradientEdge):
+        # We have already checked that we are not a C++ NestedTensor
+        if is_grads_batched:
+            raise RuntimeError("Batched grads are not supported with GradientEdge")
+        out_metadata = output.node._input_metadata[output.output_nr]
+        return torch.Size(out_metadata.shape), grad.shape
+
     if output.is_nested and not isinstance(output, NestedTensor):
         if is_grads_batched:
             raise RuntimeError("Batched grads are not supported with Nested Tensor.")
@@ -76,27 +85,58 @@ def _calculate_shape(
 
 
 def _make_grads(
-    outputs: Sequence[torch.Tensor],
+    outputs: Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]],
     grads: Sequence[_OptionalTensor],
     is_grads_batched: bool,
 ) -> Tuple[_OptionalTensor, ...]:
     new_grads: List[_OptionalTensor] = []
     for out, grad in zip(outputs, grads):
+        out = cast(Union[torch.Tensor, graph.GradientEdge], out)
+        out_size = None
+        out_device = None
+
+        if isinstance(out, graph.GradientEdge):
+            out_metadata = out.node._input_metadata[out.output_nr]
+            out_size = torch.Size(out_metadata.shape)
+            out_dtype = out_metadata.dtype
+            out_device = out_metadata.device
+            out_is_nested = out_metadata.is_nested_tensor
+            if out_metadata.is_cpp_nested_tensor:
+                raise RuntimeError(
+                    "C++ NestedTensor are not supported with GradientEdge"
+                )
+            out_is_cpp_nested = False
+        else:
+            # circular import
+            from torch.nested._internal.nested_tensor import NestedTensor
+
+            assert isinstance(out, torch.Tensor)
+            out_dtype = out.dtype
+            out_is_nested = out.is_nested
+            out_is_cpp_nested = out_is_nested and not isinstance(out, NestedTensor)
+            if not out_is_cpp_nested:
+                out_size = out.shape
+
         if isinstance(grad, torch.Tensor):
             from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq
 
             first_grad = grad if not is_grads_batched else grad[0]
+
             # TODO: We can remove this conditional once we uniformly use
             # singleton int to represent jagged dimension, so that size() call
-            # on nested tensor works
-            if out.is_nested or first_grad.is_nested:
+            # on nested tensor works.
+            if out_is_cpp_nested:
+                assert isinstance(out, torch.Tensor)
                 shape_matches = torch.is_same_size(out, first_grad)
             else:
                 # We need to do a regular size check, without going through
                 # the operator, to be able to handle unbacked symints
                 # (expect_true ensures we can deal with unbacked)
-                shape_matches = expect_true(sym_eq(out.size(), first_grad.size()))
+                assert out_size is not None
+                shape_matches = expect_true(sym_eq(out_size, first_grad.size()))
+
             if not shape_matches:
+                out = cast(Union[torch.Tensor, graph.GradientEdge], out)
                 out_shape, grad_shape = _calculate_shape(
                     out, first_grad, is_grads_batched
                 )
@@ -130,7 +170,7 @@ def _make_grads(
                     + str(out_shape)
                     + "."
                 )
-            if out.dtype.is_complex != grad.dtype.is_complex:
+            if out_dtype.is_complex != grad.dtype.is_complex:
                 raise RuntimeError(
                     "For complex Tensors, both grad_output and output"
                     " are required to have the same dtype."
@@ -141,25 +181,43 @@ def _make_grads(
                     + " and output["
                     + str(outputs.index(out))
                     + "] has a dtype of "
-                    + str(out.dtype)
+                    + str(out_dtype)
                     + "."
                 )
             new_grads.append(grad)
         elif grad is None:
-            if out.requires_grad:
-                if out.numel() != 1:
+            if isinstance(out, graph.GradientEdge) or out.requires_grad:  # type: ignore[attr-defined]
+                if isinstance(out, graph.GradientEdge):
+                    assert out_size is not None
+                    out_numel_is_1 = all(o == 1 for o in out_size)
+                else:
+                    assert isinstance(out, torch.Tensor)
+                    out_numel_is_1 = out.numel() == 1
+                if not out_numel_is_1:
                     raise RuntimeError(
                         "grad can be implicitly created only for scalar outputs"
                     )
-                if not out.dtype.is_floating_point:
+                if not out_dtype.is_floating_point:
                     msg = (
                         "grad can be implicitly created only for real scalar outputs"
-                        f" but got {out.dtype}"
+                        f" but got {out_dtype}"
                     )
                     raise RuntimeError(msg)
-                new_grads.append(
-                    torch.ones_like(out, memory_format=torch.preserve_format)
-                )
+                if isinstance(out, graph.GradientEdge):
+                    assert out_size is not None
+                    assert out_device is not None
+                    new_grads.append(
+                        torch.ones(
+                            out_size,
+                            dtype=out_dtype,
+                            device=out_device,
+                        )
+                    )
+                else:
+                    assert isinstance(out, torch.Tensor)
+                    new_grads.append(
+                        torch.ones_like(out, memory_format=torch.preserve_format)
+                    )
             else:
                 new_grads.append(None)
         else:
@@ -297,7 +355,7 @@ def backward(
 
 
 def grad(
-    outputs: _TensorOrTensors,
+    outputs: _TensorOrTensorsOrGradEdge,
     inputs: _TensorOrTensorsOrGradEdge,
     grad_outputs: Optional[_TensorOrTensors] = None,
     retain_graph: Optional[bool] = None,
@@ -327,7 +385,7 @@ def grad(
         ``torch.autograd.backward``.
 
     Args:
-        outputs (sequence of Tensor): outputs of the differentiated function.
+        outputs (sequence of Tensor or GradientEdge): outputs of the differentiated function.
         inputs (sequence of Tensor or GradientEdge): Inputs w.r.t. which the gradient
             will be returned (and not accumulated into ``.grad``).
         grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product.
@@ -369,21 +427,24 @@ def grad(
     )
     if allow_unused is None:
         allow_unused = materialize_grads
-    t_outputs = cast(
-        Tuple[torch.Tensor, ...],
-        (outputs,) if is_tensor_like(outputs) else tuple(outputs),
-    )
+    if is_tensor_like(outputs) or isinstance(outputs, graph.GradientEdge):
+        outputs = cast(
+            Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
+        )
+    else:
+        outputs = tuple(outputs)
     if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
         inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
     else:
         inputs = tuple(inputs)
+    t_outputs = tuple(i for i in outputs if is_tensor_like(i))
     t_inputs = tuple(i for i in inputs if is_tensor_like(i))
     overridable_args = t_outputs + t_inputs
     if has_torch_function(overridable_args):
         return handle_torch_function(
             grad,
             overridable_args,
-            t_outputs,
+            outputs,
             inputs,
             grad_outputs=grad_outputs,
             retain_graph=retain_graph,
@@ -403,9 +464,9 @@ def grad(
             stacklevel=2,
         )
 
-    grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(t_outputs))
+    grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
     grad_outputs_ = _make_grads(
-        t_outputs, grad_outputs_, is_grads_batched=is_grads_batched
+        outputs, grad_outputs_, is_grads_batched=is_grads_batched
     )
 
     if retain_graph is None:
@@ -418,7 +479,7 @@ def grad(
 
         def vjp(gO):
             return _engine_run_backward(
-                t_outputs,
+                outputs,
                 gO,
                 retain_graph,
                 create_graph,
@@ -432,7 +493,7 @@ def grad(
         )
     else:
         result = _engine_run_backward(
-            t_outputs,
+            outputs,
             grad_outputs_,
             retain_graph,
             create_graph,
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index 5482120da43..7f9a38aaa08 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -79,6 +79,11 @@ class Node(abc.ABC):
         r"""Return the metadata."""
         raise NotImplementedError
 
+    @property
+    @abc.abstractmethod
+    def _input_metadata(self) -> List[Any]:
+        raise NotImplementedError
+
     @abc.abstractmethod
     def _register_hook_dict(self, tensor: torch.Tensor) -> None:
         raise NotImplementedError
@@ -170,7 +175,9 @@ class Node(abc.ABC):
         return NotImplemented
 
 
-def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node:
+def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node:
+    if isinstance(t, GradientEdge):
+        return t.node
     if t.requires_grad and t.grad_fn is None:
         node = t.view_as(t).grad_fn.next_functions[0][0]  # type: ignore[union-attr]
     else:
@@ -738,7 +745,7 @@ def allow_mutation_on_saved_tensors() -> (
 
 
 def _register_logging_hooks_on_whole_graph(
-    t_outputs: Sequence[torch.Tensor],
+    t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
 ) -> Callable[[], None]:
     grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))
 
@@ -788,7 +795,7 @@ def _register_logging_hooks_on_whole_graph(
 
 
 def _engine_run_backward(
-    t_outputs: Sequence[torch.Tensor],
+    t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
     *args: Any,
     **kwargs: Any,
 ) -> Tuple[torch.Tensor, ...]:
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index df731a64d8b..8a0c75cd781 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -16,6 +16,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -184,6 +185,23 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
       .value("NO_GRAD_MODE", CreationMeta::NO_GRAD_MODE)
       .value("INFERENCE_MODE", CreationMeta::INFERENCE_MODE);
 
+  py::class_<torch::autograd::InputMetadata>(m, "_InputMetadata")
+      .def_property_readonly(
+          "dtype",
+          [](const torch::autograd::InputMetadata& m) {
+            PyObject* raw_obj =
+                (PyObject*)torch::getTHPDtype(m.dtype().toScalarType());
+            return py::reinterpret_borrow<py::object>(raw_obj);
+          })
+      .def_property_readonly("device", &torch::autograd::InputMetadata::device)
+      .def_property_readonly(
+          "shape", &torch::autograd::InputMetadata::shape_as_dim_vector)
+      .def_property_readonly(
+          "is_nested_tensor", &torch::autograd::InputMetadata::is_nested_tensor)
+      .def_property_readonly(
+          "is_cpp_nested_tensor",
+          &torch::autograd::InputMetadata::is_cpp_nested_tensor);
+
   py::class_<KinetoEvent>(m, "_KinetoEvent")
       // name of the event
       .def("name", [](const KinetoEvent& e) { return e.name(); })
diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp
index ec057f91df4..e17ff88cce1 100644
--- a/torch/csrc/autograd/python_cpp_function.cpp
+++ b/torch/csrc/autograd/python_cpp_function.cpp
@@ -210,6 +210,27 @@ PyObject* THPCppFunction_set_sequence_nr(
   END_HANDLE_TH_ERRORS
 }
 
+PyObject* THPCppFunction_input_metadata(PyObject* self, void* closure) {
+  HANDLE_TH_ERRORS;
+  auto& fn = *((THPCppFunction*)self)->cdata;
+  const auto num_inputs =
+      fn.num_inputs(); // Assuming there's a method to get the number of inputs
+  THPObjectPtr list(PyTuple_New(num_inputs));
+  if (!list) {
+    return nullptr;
+  }
+  for (size_t i = 0; i < num_inputs; ++i) {
+    const auto& metadata = fn.input_metadata(i);
+    THPObjectPtr item(py::cast(metadata).release().ptr());
+    if (!item) {
+      return nullptr;
+    }
+    PyTuple_SET_ITEM(list.get(), i, item.release());
+  }
+  return list.release();
+  END_HANDLE_TH_ERRORS
+}
+
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
 static struct PyMethodDef default_methods[] = {
     THP_FUNCTION_DEFAULT_METHODS,
diff --git a/torch/csrc/autograd/python_cpp_function.h b/torch/csrc/autograd/python_cpp_function.h
index bd81a28334b..832ab1c7677 100644
--- a/torch/csrc/autograd/python_cpp_function.h
+++ b/torch/csrc/autograd/python_cpp_function.h
@@ -51,19 +51,21 @@ PyObject* CppFunction_pynew(
       (char*)"_set_sequence_nr", THPCppFunction_set_sequence_nr, METH_O, nullptr \
   }
 
-#define THP_FUNCTION_DEFAULT_PROPERTIES                                    \
-  {(char*)"next_functions",                                                \
-   THPCppFunction_next_functions,                                          \
-   nullptr,                                                                \
-   nullptr,                                                                \
-   nullptr},                                                               \
-      {(char*)"requires_grad",                                             \
-       THPCppFunction_requires_grad,                                       \
-       nullptr,                                                            \
-       nullptr,                                                            \
-       nullptr},                                                           \
-  {                                                                        \
-    (char*)"metadata", THPCppFunction_metadata, nullptr, nullptr, nullptr  \
-  }
+#define THP_FUNCTION_DEFAULT_PROPERTIES                                        \
+  {(char*)"next_functions",                                                    \
+   THPCppFunction_next_functions,                                              \
+   nullptr,                                                                    \
+   nullptr,                                                                    \
+   nullptr},                                                                   \
+      {(char*)"requires_grad",                                                 \
+       THPCppFunction_requires_grad,                                           \
+       nullptr,                                                                \
+       nullptr,                                                                \
+       nullptr},                                                               \
+      {(char*)"metadata", THPCppFunction_metadata, nullptr, nullptr, nullptr}, \
+  {                                                                            \
+    (char*)"_input_metadata", THPCppFunction_input_metadata, nullptr, nullptr, \
+    nullptr                                                                    \
+  }
 
 PyObject* THPCppFunction_next_functions(PyObject* self, void* _unused);
@@ -75,6 +77,7 @@ PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook);
 
 PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs);
 PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs);
+PyObject* THPCppFunction_input_metadata(PyObject* self, void* _unused);
 
 PyTypeObject* _initFunctionPyTypeObject(
     PyTypeObject& type,
diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp
index 0d7265370c4..38646701ebd 100644
--- a/torch/csrc/autograd/python_engine.cpp
+++ b/torch/csrc/autograd/python_engine.cpp
@@ -166,6 +166,24 @@ c10::intrusive_ptr PythonEngine::execute_with_graph_task(
 
 PyObject* THPEngineClass = nullptr;
 
+inline static Edge parseGradientEdge(PyObject* obj, int64_t index) {
+  PyObject* grad_fn = PyTuple_GetItem(obj, 0);
+  auto output_nr = THPUtils_unpackLong(PyTuple_GetItem(obj, 1));
+  std::shared_ptr<Node> grad_fn_sp;
+  if (THPFunction_Check(grad_fn)) {
+    grad_fn_sp = ((THPFunction*)grad_fn)->cdata.lock();
+  } else if (THPCppFunction_Check(grad_fn)) {
+    grad_fn_sp = ((THPCppFunction*)grad_fn)->cdata;
+  } else {
+    TORCH_CHECK(
+        false,
+        "GradientEdge's first object must be an autograd.graph.Node "
+        "but got ",
+        THPUtils_typename(grad_fn));
+  }
+  return Edge(grad_fn_sp, output_nr);
+}
+
 // Implementation of torch._C._EngineBase.run_backward
 PyObject* THPEngine_run_backward(
     PyObject* self,
@@ -239,22 +257,29 @@ PyObject* THPEngine_run_backward(
   grads.reserve(num_tensors);
   for (const auto i : c10::irange(num_tensors)) {
     PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
-    TORCH_CHECK(
-        THPVariable_Check(_tensor),
-        "element ",
-        i,
-        " of tensors tuple is not a Tensor");
-    const auto& variable = THPVariable_Unpack(_tensor);
-    TORCH_CHECK(
-        !isBatchedTensor(variable),
-        "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
-        "torch.vmap. We do not support the case where any outputs are ",
-        "vmapped tensors (output ",
-        i,
-        " is being vmapped over). Please "
-        "call autograd.grad() outside torch.vmap or file a bug report "
-        "with your use case.")
-    auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
+    Edge gradient_edge; // Temporary variable to hold the gradient edge
+    c10::optional<at::Tensor> mb_output;
+    if (THPVariable_Check(_tensor)) {
+      mb_output = THPVariable_Unpack(_tensor);
+      TORCH_CHECK(
+          !isBatchedTensor(mb_output.value()),
+          "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
+          "torch.vmap. We do not support the case where any outputs are ",
+          "vmapped tensors (output ",
+          i,
+          " is being vmapped over). Please "
+          "call autograd.grad() outside torch.vmap or file a bug report "
+          "with your use case.");
+      gradient_edge = torch::autograd::impl::gradient_edge(mb_output.value());
+    } else if (PyObject_IsInstance(_tensor, THPGradientEdgeClass)) {
+      gradient_edge = parseGradientEdge(_tensor, i);
+    } else {
+      TORCH_CHECK(
+          false,
+          "element ",
+          i,
+          " of tensors tuple is neither a Tensor nor a GradientEdge");
+    }
     TORCH_CHECK(
         gradient_edge.function,
         "element ",
@@ -281,7 +306,13 @@ PyObject* THPEngine_run_backward(
           i,
           " of gradients tuple is not a Tensor or None");
       TORCH_CHECK(
-          !variable.requires_grad(),
+          mb_output.has_value(),
+          "element ",
+          i,
+          " of gradients tuple is None, but the corresponding output is a GradientEdge."
+          "This is not supported.");
+      TORCH_CHECK(
+          !mb_output.value().requires_grad(),
           "element ",
           i,
           " of gradients tuple is None, but the corresponding Tensor requires grad");
@@ -330,23 +361,7 @@ PyObject* THPEngine_run_backward(
         output_edges.emplace_back(grad_fn, output_nr);
       }
     } else if (PyObject_IsInstance(input, THPGradientEdgeClass)) {
-      auto node = PyTuple_GetItem(input, 0);
-      bool isTHPFunction = THPFunction_Check(node);
-      bool isTHPCppFunction = THPCppFunction_Check(node);
-      TORCH_CHECK(
-          isTHPFunction || isTHPCppFunction,
-          "GradientEdge first object must be an autograd.graph.Node "
-          "but got ",
-          THPUtils_typename(node));
-      std::shared_ptr<Node> node_sp;
-      if (isTHPFunction) {
-        node_sp = ((THPFunction*)node)->cdata.lock();
-      } else {
-        node_sp = ((torch::autograd::THPCppFunction*)node)->cdata;
-      }
-
-      auto output_nr = THPUtils_unpackUInt32(PyTuple_GetItem(input, 1));
-      output_edges.emplace_back(node_sp, output_nr);
+      output_edges.emplace_back(parseGradientEdge(input, i));
     } else {
       TORCH_CHECK(
           false,
diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp
index a5ba07b2cdb..03022511cc9 100644
--- a/torch/csrc/autograd/python_function.cpp
+++ b/torch/csrc/autograd/python_function.cpp
@@ -1181,6 +1181,26 @@ PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) {
   END_HANDLE_TH_ERRORS
 }
 
+PyObject* THPFunction_input_metadata(PyObject* self, void* unused) {
+  HANDLE_TH_ERRORS;
+  auto cdata = ((THPFunction*)self)->cdata.lock();
+  const auto num_inputs = cdata->num_inputs();
+  THPObjectPtr list(PyTuple_New(num_inputs));
+  if (!list) {
+    return nullptr;
+  }
+  for (size_t i = 0; i < num_inputs; ++i) {
+    const auto& metadata = cdata->input_metadata(i);
+    THPObjectPtr item(py::cast(metadata).release().ptr());
+    if (!item) {
+      return nullptr;
+    }
+    PyTuple_SET_ITEM(list.get(), i, item.release());
+  }
+  return list.release();
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject* THPFunction_maybe_clear_saved_tensors(
     PyObject* self,
     PyObject* noargs) {
@@ -1723,6 +1743,11 @@ static struct PyGetSetDef THPFunction_properties[] = {
      nullptr},
     {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},
     {"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr},
+    {"_input_metadata",
+     (getter)THPFunction_input_metadata,
+     nullptr,
+     nullptr,
+     nullptr},
     {"materialize_grads",
      nullptr,
      (setter)THPFunction_set_materialize_grads,