diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 0350f68bffb..ef0651f9921 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -2277,6 +2277,7 @@ known_failing_tests = { "test_grad_materialize_grads", # RuntimeError: compiled_autograd does not support create_graph "test_grad_nonleaf", # RuntimeError: compiled_autograd does not support create_graph "test_grad_nonleaf_many_outputs", # RuntimeError: compiled_autograd does not support create_graph + "test_gradient_edge_output", # RuntimeError: trying to backward through the graph a second time "test_hessian_vector", # RuntimeError: compiled_autograd does not support create_graph "test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_False", # AttributeError: type object "test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_True", # AttributeError: type object diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index 6d64f1175fa..580c9ea12d9 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -436,6 +436,7 @@ class TestProfilerTree(TestCase): [memory]""", ) + @unittest.skip("https://github.com/pytorch/pytorch/issues/83606") @unittest.skipIf( TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite." ) @@ -478,8 +479,19 @@ class TestProfilerTree(TestCase): torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple torch/autograd/__init__.py(...): _make_grads + typing.py(...): inner + typing.py(...): __hash__ + + typing.py(...): cast + + + + + + + aten::ones_like aten::empty_like @@ -910,6 +922,9 @@ class TestProfilerTree(TestCase): @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") @ProfilerTree.test def test_profiler_experimental_tree_cuda_detailed(self): + # Do lazy imports ahead of time to avoid it showing up in the tree + import torch.nested._internal.nested_tensor + model = torch.nn.modules.Linear(1, 1, device="cuda") model.train() opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) @@ -967,8 +982,19 @@ class TestProfilerTree(TestCase): torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple torch/autograd/__init__.py(...): _make_grads + typing.py(...): inner + typing.py(...): __hash__ + + typing.py(...): cast + + + + + + + aten::ones_like aten::empty_like diff --git a/test/test_autograd.py b/test/test_autograd.py index 4f7c095841f..0b00a367fcd 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -985,6 +985,113 @@ class TestAutograd(TestCase): torch.autograd.backward(out.sum(), inputs=(x, edge_y)) torch.autograd.backward(out.sum(), inputs=(edge_x, edge_y)) + def test_grad_fn_input_metadata(self): + x = torch.rand(2, requires_grad=True, dtype=torch.float32) + y = torch.rand(2, requires_grad=True, dtype=torch.float32) + z = x * y + z_metadata = z.grad_fn._input_metadata[0] + self.assertEqual(z_metadata.shape, (2,)) + self.assertEqual(z_metadata.dtype, torch.float32) + + # Multiple outputs + b = torch.rand(3, 3, requires_grad=True) + var, _ = torch.var_mean(b, dim=0) + + metadata_0 = var.grad_fn._input_metadata[0] + metadata_1 = var.grad_fn._input_metadata[1] + self.assertEqual(metadata_0.shape, (3,)) + self.assertEqual(metadata_1.shape, (3,)) + + # Preserves symints + nt = torch.nested.nested_tensor( + [ + torch.randn( + 3, + 2, + ), + torch.randn( + 2, + 2, + ), + ], + layout=torch.jagged, + requires_grad=True, + ) + nt_metadata = nt.clone().grad_fn._input_metadata[0] + + self.assertIsInstance(nt_metadata.shape[1], torch.SymInt) + self.assertEqual(nt_metadata.shape, nt.shape) + self.assertTrue(nt_metadata.is_nested_tensor) + self.assertFalse(nt_metadata.is_cpp_nested_tensor) + self.assertEqual(nt_metadata.dtype, nt.dtype) + + class Test(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + x = torch.randn(3, 3, requires_grad=True) + x = Test.apply(x) + metadata = x.grad_fn._input_metadata[0] + self.assertEqual(metadata.shape, (3, 3)) + + def test_gradient_edge_output(self): + x = torch.tensor([1.0, 2.0], requires_grad=True) + + def fn(x, reduce=True): + tmp = x.sin().cos() + if reduce: + tmp = tmp.sum() + out = tmp.exp().clone().sin().sum() + tmp_edge = torch.autograd.graph.get_gradient_edge(tmp) + return out, tmp_edge + + # Compute fn backward in two steps + out, tmp_edge = fn(x) + (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,)) + + (x_grad,) = torch.autograd.grad(tmp_edge, (x,), grad_outputs=(tmp_grad,)) + + # Compare with as if we did it in one go. + out, _ = fn(x) + (x_grad_ref,) = torch.autograd.grad(out, (x,)) + self.assertEqual(x_grad, x_grad_ref) + + # Incorrect case: grad_outputs not passed/implicitly None and output is + # not a scalar + out, tmp_edge = fn(x, reduce=False) + with self.assertRaisesRegex( + RuntimeError, + "grad can be implicitly created only for scalar output", + ): + torch.autograd.grad(tmp_edge, (x,)) + + # grad_outputs is None, and output is a scalar is fine + out, tmp_edge = fn(x, reduce=True) + torch.autograd.grad(tmp_edge, (x,)) + + # Incorrect case: grad_outputs wrong size + out, tmp_edge = fn(x) + (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,)) + with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"): + torch.autograd.grad( + tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0]) + ) + + # Incorrect case: wrong dtype + out, tmp_edge = fn(x) + (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,)) + with self.assertRaisesRegex(RuntimeError, "required to have the same dtype"): + torch.autograd.grad( + tmp_edge, + (x,), + grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64), + ) + def test_grad_nonleaf(self): x_init = torch.randn(2, 2, requires_grad=True) x = x_init diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index 5c68b94b001..a11491ac832 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -56,12 +56,21 @@ _ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor] def _calculate_shape( - output: torch.Tensor, grad: torch.Tensor, is_grads_batched: bool + output: Union[torch.Tensor, graph.GradientEdge], + grad: torch.Tensor, + is_grads_batched: bool, ) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]: # is_same_size ensures that both tensors are either nested or non nested # circular import from torch.nested._internal.nested_tensor import NestedTensor + if isinstance(output, graph.GradientEdge): + # We have already checked that we are not a C++ NestedTensor + if is_grads_batched: + raise RuntimeError("Batched grads are not supported with GradientEdge") + out_metadata = output.node._input_metadata[output.output_nr] + return torch.Size(out_metadata.shape), grad.shape + if output.is_nested and not isinstance(output, NestedTensor): if is_grads_batched: raise RuntimeError("Batched grads are not supported with Nested Tensor.") @@ -76,27 +85,58 @@ def _calculate_shape( def _make_grads( - outputs: Sequence[torch.Tensor], + outputs: Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], grads: Sequence[_OptionalTensor], is_grads_batched: bool, ) -> Tuple[_OptionalTensor, ...]: new_grads: List[_OptionalTensor] = [] for out, grad in zip(outputs, grads): + out = cast(Union[torch.Tensor, graph.GradientEdge], out) + out_size = None + out_device = None + + if isinstance(out, graph.GradientEdge): + out_metadata = out.node._input_metadata[out.output_nr] + out_size = torch.Size(out_metadata.shape) + out_dtype = out_metadata.dtype + out_device = out_metadata.device + out_is_nested = out_metadata.is_nested_tensor + if out_metadata.is_cpp_nested_tensor: + raise RuntimeError( + "C++ NestedTensor are not supported with GradientEdge" + ) + out_is_cpp_nested = False + else: + # circular import + from torch.nested._internal.nested_tensor import NestedTensor + + assert isinstance(out, torch.Tensor) + out_dtype = out.dtype + out_is_nested = out.is_nested + out_is_cpp_nested = out_is_nested and not isinstance(out, NestedTensor) + if not out_is_cpp_nested: + out_size = out.shape + if isinstance(grad, torch.Tensor): from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq first_grad = grad if not is_grads_batched else grad[0] + # TODO: We can remove this conditional once we uniformly use # singleton int to represent jagged dimension, so that size() call - # on nested tensor works - if out.is_nested or first_grad.is_nested: + # on nested tensor works. + if out_is_cpp_nested: + assert isinstance(out, torch.Tensor) shape_matches = torch.is_same_size(out, first_grad) else: # We need to do a regular size check, without going through # the operator, to be able to handle unbacked symints # (expect_true ensures we can deal with unbacked) - shape_matches = expect_true(sym_eq(out.size(), first_grad.size())) + assert out_size is not None + shape_matches = expect_true(sym_eq(out_size, first_grad.size())) + if not shape_matches: + out = cast(Union[torch.Tensor, graph.GradientEdge], out) out_shape, grad_shape = _calculate_shape( out, first_grad, is_grads_batched ) @@ -130,7 +170,7 @@ def _make_grads( + str(out_shape) + "." ) - if out.dtype.is_complex != grad.dtype.is_complex: + if out_dtype.is_complex != grad.dtype.is_complex: raise RuntimeError( "For complex Tensors, both grad_output and output" " are required to have the same dtype." @@ -141,25 +181,43 @@ def _make_grads( + " and output[" + str(outputs.index(out)) + "] has a dtype of " - + str(out.dtype) + + str(out_dtype) + "." ) new_grads.append(grad) elif grad is None: - if out.requires_grad: - if out.numel() != 1: + if isinstance(out, graph.GradientEdge) or out.requires_grad: # type: ignore[attr-defined] + if isinstance(out, graph.GradientEdge): + assert out_size is not None + out_numel_is_1 = all(o == 1 for o in out_size) + else: + assert isinstance(out, torch.Tensor) + out_numel_is_1 = out.numel() == 1 + if not out_numel_is_1: raise RuntimeError( "grad can be implicitly created only for scalar outputs" ) - if not out.dtype.is_floating_point: + if not out_dtype.is_floating_point: msg = ( "grad can be implicitly created only for real scalar outputs" - f" but got {out.dtype}" + f" but got {out_dtype}" ) raise RuntimeError(msg) - new_grads.append( - torch.ones_like(out, memory_format=torch.preserve_format) - ) + if isinstance(out, graph.GradientEdge): + assert out_size is not None + assert out_device is not None + new_grads.append( + torch.ones( + out_size, + dtype=out_dtype, + device=out_device, + ) + ) + else: + assert isinstance(out, torch.Tensor) + new_grads.append( + torch.ones_like(out, memory_format=torch.preserve_format) + ) else: new_grads.append(None) else: @@ -297,7 +355,7 @@ def backward( def grad( - outputs: _TensorOrTensors, + outputs: _TensorOrTensorsOrGradEdge, inputs: _TensorOrTensorsOrGradEdge, grad_outputs: Optional[_TensorOrTensors] = None, retain_graph: Optional[bool] = None, @@ -327,7 +385,7 @@ def grad( ``torch.autograd.backward``. Args: - outputs (sequence of Tensor): outputs of the differentiated function. + outputs (sequence of Tensor or GradientEdge): outputs of the differentiated function. inputs (sequence of Tensor or GradientEdge): Inputs w.r.t. which the gradient will be returned (and not accumulated into ``.grad``). grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product. @@ -369,21 +427,24 @@ def grad( ) if allow_unused is None: allow_unused = materialize_grads - t_outputs = cast( - Tuple[torch.Tensor, ...], - (outputs,) if is_tensor_like(outputs) else tuple(outputs), - ) + if is_tensor_like(outputs) or isinstance(outputs, graph.GradientEdge): + outputs = cast( + Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,) + ) + else: + outputs = tuple(outputs) if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge): inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,)) else: inputs = tuple(inputs) + t_outputs = tuple(i for i in outputs if is_tensor_like(i)) t_inputs = tuple(i for i in inputs if is_tensor_like(i)) overridable_args = t_outputs + t_inputs if has_torch_function(overridable_args): return handle_torch_function( grad, overridable_args, - t_outputs, + outputs, inputs, grad_outputs=grad_outputs, retain_graph=retain_graph, @@ -403,9 +464,9 @@ def grad( stacklevel=2, ) - grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(t_outputs)) + grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs)) grad_outputs_ = _make_grads( - t_outputs, grad_outputs_, is_grads_batched=is_grads_batched + outputs, grad_outputs_, is_grads_batched=is_grads_batched ) if retain_graph is None: @@ -418,7 +479,7 @@ def grad( def vjp(gO): return _engine_run_backward( - t_outputs, + outputs, gO, retain_graph, create_graph, @@ -432,7 +493,7 @@ def grad( ) else: result = _engine_run_backward( - t_outputs, + outputs, grad_outputs_, retain_graph, create_graph, diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 5482120da43..7f9a38aaa08 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -79,6 +79,11 @@ class Node(abc.ABC): r"""Return the metadata.""" raise NotImplementedError + @property + @abc.abstractmethod + def _input_metadata(self) -> List[Any]: + raise NotImplementedError + @abc.abstractmethod def _register_hook_dict(self, tensor: torch.Tensor) -> None: raise NotImplementedError @@ -170,7 +175,9 @@ class Node(abc.ABC): return NotImplemented -def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node: +def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node: + if isinstance(t, GradientEdge): + return t.node if t.requires_grad and t.grad_fn is None: node = t.view_as(t).grad_fn.next_functions[0][0] # type: ignore[union-attr] else: @@ -738,7 +745,7 @@ def allow_mutation_on_saved_tensors() -> ( def _register_logging_hooks_on_whole_graph( - t_outputs: Sequence[torch.Tensor], + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], ) -> Callable[[], None]: grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs)) @@ -788,7 +795,7 @@ def _register_logging_hooks_on_whole_graph( def _engine_run_backward( - t_outputs: Sequence[torch.Tensor], + t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], *args: Any, **kwargs: Any, ) -> Tuple[torch.Tensor, ...]: diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index df731a64d8b..8a0c75cd781 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -184,6 +185,23 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { .value("NO_GRAD_MODE", CreationMeta::NO_GRAD_MODE) .value("INFERENCE_MODE", CreationMeta::INFERENCE_MODE); + py::class_(m, "_InputMetadata") + .def_property_readonly( + "dtype", + [](const torch::autograd::InputMetadata& m) { + PyObject* raw_obj = + (PyObject*)torch::getTHPDtype(m.dtype().toScalarType()); + return py::reinterpret_borrow(raw_obj); + }) + .def_property_readonly("device", &torch::autograd::InputMetadata::device) + .def_property_readonly( + "shape", &torch::autograd::InputMetadata::shape_as_dim_vector) + .def_property_readonly( + "is_nested_tensor", &torch::autograd::InputMetadata::is_nested_tensor) + .def_property_readonly( + "is_cpp_nested_tensor", + &torch::autograd::InputMetadata::is_cpp_nested_tensor); + py::class_(m, "_KinetoEvent") // name of the event .def("name", [](const KinetoEvent& e) { return e.name(); }) diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp index ec057f91df4..e17ff88cce1 100644 --- a/torch/csrc/autograd/python_cpp_function.cpp +++ b/torch/csrc/autograd/python_cpp_function.cpp @@ -210,6 +210,27 @@ PyObject* THPCppFunction_set_sequence_nr( END_HANDLE_TH_ERRORS } +PyObject* THPCppFunction_input_metadata(PyObject* self, void* closure) { + HANDLE_TH_ERRORS; + auto& fn = *((THPCppFunction*)self)->cdata; + const auto num_inputs = + fn.num_inputs(); // Assuming there's a method to get the number of inputs + THPObjectPtr list(PyTuple_New(num_inputs)); + if (!list) { + return nullptr; + } + for (size_t i = 0; i < num_inputs; ++i) { + const auto& metadata = fn.input_metadata(i); + THPObjectPtr item(py::cast(metadata).release().ptr()); + if (!item) { + return nullptr; + } + PyTuple_SET_ITEM(list.get(), i, item.release()); + } + return list.release(); + END_HANDLE_TH_ERRORS +} + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) static struct PyMethodDef default_methods[] = { THP_FUNCTION_DEFAULT_METHODS, diff --git a/torch/csrc/autograd/python_cpp_function.h b/torch/csrc/autograd/python_cpp_function.h index bd81a28334b..832ab1c7677 100644 --- a/torch/csrc/autograd/python_cpp_function.h +++ b/torch/csrc/autograd/python_cpp_function.h @@ -51,19 +51,21 @@ PyObject* CppFunction_pynew( (char*)"_set_sequence_nr", THPCppFunction_set_sequence_nr, METH_O, nullptr \ } -#define THP_FUNCTION_DEFAULT_PROPERTIES \ - {(char*)"next_functions", \ - THPCppFunction_next_functions, \ - nullptr, \ - nullptr, \ - nullptr}, \ - {(char*)"requires_grad", \ - THPCppFunction_requires_grad, \ - nullptr, \ - nullptr, \ - nullptr}, \ - { \ - (char*)"metadata", THPCppFunction_metadata, nullptr, nullptr, nullptr \ +#define THP_FUNCTION_DEFAULT_PROPERTIES \ + {(char*)"next_functions", \ + THPCppFunction_next_functions, \ + nullptr, \ + nullptr, \ + nullptr}, \ + {(char*)"requires_grad", \ + THPCppFunction_requires_grad, \ + nullptr, \ + nullptr, \ + nullptr}, \ + {(char*)"metadata", THPCppFunction_metadata, nullptr, nullptr, nullptr}, \ + { \ + (char*)"_input_metadata", THPCppFunction_input_metadata, nullptr, nullptr, \ + nullptr \ } PyObject* THPCppFunction_next_functions(PyObject* self, void* _unused); @@ -75,6 +77,7 @@ PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook); PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs); PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs); +PyObject* THPCppFunction_input_metadata(PyObject* self, void* _unused); PyTypeObject* _initFunctionPyTypeObject( PyTypeObject& type, diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index 0d7265370c4..38646701ebd 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -166,6 +166,24 @@ c10::intrusive_ptr PythonEngine::execute_with_graph_task( PyObject* THPEngineClass = nullptr; +inline static Edge parseGradientEdge(PyObject* obj, int64_t index) { + PyObject* grad_fn = PyTuple_GetItem(obj, 0); + auto output_nr = THPUtils_unpackLong(PyTuple_GetItem(obj, 1)); + std::shared_ptr grad_fn_sp; + if (THPFunction_Check(grad_fn)) { + grad_fn_sp = ((THPFunction*)grad_fn)->cdata.lock(); + } else if (THPCppFunction_Check(grad_fn)) { + grad_fn_sp = ((THPCppFunction*)grad_fn)->cdata; + } else { + TORCH_CHECK( + false, + "GradientEdge's first object must be an autograd.graph.Node " + "but got ", + THPUtils_typename(grad_fn)); + } + return Edge(grad_fn_sp, output_nr); +} + // Implementation of torch._C._EngineBase.run_backward PyObject* THPEngine_run_backward( PyObject* self, @@ -239,22 +257,29 @@ PyObject* THPEngine_run_backward( grads.reserve(num_tensors); for (const auto i : c10::irange(num_tensors)) { PyObject* _tensor = PyTuple_GET_ITEM(tensors, i); - TORCH_CHECK( - THPVariable_Check(_tensor), - "element ", - i, - " of tensors tuple is not a Tensor"); - const auto& variable = THPVariable_Unpack(_tensor); - TORCH_CHECK( - !isBatchedTensor(variable), - "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ", - "torch.vmap. We do not support the case where any outputs are ", - "vmapped tensors (output ", - i, - " is being vmapped over). Please " - "call autograd.grad() outside torch.vmap or file a bug report " - "with your use case.") - auto gradient_edge = torch::autograd::impl::gradient_edge(variable); + Edge gradient_edge; // Temporary variable to hold the gradient edge + c10::optional mb_output; + if (THPVariable_Check(_tensor)) { + mb_output = THPVariable_Unpack(_tensor); + TORCH_CHECK( + !isBatchedTensor(mb_output.value()), + "torch.autograd.grad(outputs, inputs, grad_outputs) called inside ", + "torch.vmap. We do not support the case where any outputs are ", + "vmapped tensors (output ", + i, + " is being vmapped over). Please " + "call autograd.grad() outside torch.vmap or file a bug report " + "with your use case."); + gradient_edge = torch::autograd::impl::gradient_edge(mb_output.value()); + } else if (PyObject_IsInstance(_tensor, THPGradientEdgeClass)) { + gradient_edge = parseGradientEdge(_tensor, i); + } else { + TORCH_CHECK( + false, + "element ", + i, + " of tensors tuple is neither a Tensor nor a GradientEdge"); + } TORCH_CHECK( gradient_edge.function, "element ", @@ -281,7 +306,13 @@ PyObject* THPEngine_run_backward( i, " of gradients tuple is not a Tensor or None"); TORCH_CHECK( - !variable.requires_grad(), + mb_output.has_value(), + "element ", + i, + " of gradients tuple is None, but the corresponding output is a GradientEdge." + "This is not supported."); + TORCH_CHECK( + !mb_output.value().requires_grad(), "element ", i, " of gradients tuple is None, but the corresponding Tensor requires grad"); @@ -330,23 +361,7 @@ PyObject* THPEngine_run_backward( output_edges.emplace_back(grad_fn, output_nr); } } else if (PyObject_IsInstance(input, THPGradientEdgeClass)) { - auto node = PyTuple_GetItem(input, 0); - bool isTHPFunction = THPFunction_Check(node); - bool isTHPCppFunction = THPCppFunction_Check(node); - TORCH_CHECK( - isTHPFunction || isTHPCppFunction, - "GradientEdge first object must be an autograd.graph.Node " - "but got ", - THPUtils_typename(node)); - std::shared_ptr node_sp; - if (isTHPFunction) { - node_sp = ((THPFunction*)node)->cdata.lock(); - } else { - node_sp = ((torch::autograd::THPCppFunction*)node)->cdata; - } - - auto output_nr = THPUtils_unpackUInt32(PyTuple_GetItem(input, 1)); - output_edges.emplace_back(node_sp, output_nr); + output_edges.emplace_back(parseGradientEdge(input, i)); } else { TORCH_CHECK( false, diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index a5ba07b2cdb..03022511cc9 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -1181,6 +1181,26 @@ PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) { END_HANDLE_TH_ERRORS } +PyObject* THPFunction_input_metadata(PyObject* self, void* unused) { + HANDLE_TH_ERRORS; + auto cdata = ((THPFunction*)self)->cdata.lock(); + const auto num_inputs = cdata->num_inputs(); + THPObjectPtr list(PyTuple_New(num_inputs)); + if (!list) { + return nullptr; + } + for (size_t i = 0; i < num_inputs; ++i) { + const auto& metadata = cdata->input_metadata(i); + THPObjectPtr item(py::cast(metadata).release().ptr()); + if (!item) { + return nullptr; + } + PyTuple_SET_ITEM(list.get(), i, item.release()); + } + return list.release(); + END_HANDLE_TH_ERRORS +} + PyObject* THPFunction_maybe_clear_saved_tensors( PyObject* self, PyObject* noargs) { @@ -1723,6 +1743,11 @@ static struct PyGetSetDef THPFunction_properties[] = { nullptr}, {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr}, {"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr}, + {"_input_metadata", + (getter)THPFunction_input_metadata, + nullptr, + nullptr, + nullptr}, {"materialize_grads", nullptr, (setter)THPFunction_set_materialize_grads,