mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[autograd] Support GradientEdge as output for torch.autograd.grad (#127766)
This is useful for splitting grad to run in two parts while preserving intermediates:
<details>
<summary>
Click to see code
</summary>
```python
import collections
import weakref
from torch.autograd.graph import GradientEdge
def _get_grad_fn_or_grad_acc(t):
if t.requires_grad and t.grad_fn is None:
return t.view_as(t).grad_fn.next_functions[0][0]
else:
return t.grad_fn
def reverse_closure(roots, target_nodes):
    """BFS over the "reverse_edges" metadata starting from ``roots``.

    Walks the reversed autograd graph (built by ``construct_reverse_graph``)
    and recurses until a node in ``target_nodes`` is reached; target nodes
    themselves are recorded but not expanded.

    Args:
        roots: iterable of autograd Nodes (``None`` entries are skipped).
        target_nodes: set of Nodes at which the traversal stops.

    Returns:
        ``(closure, actual_target_nodes)``: the set of visited nodes and the
        subset of ``target_nodes`` actually reached.

    Raises:
        RuntimeError: if a reverse-edge Holder has been garbage collected,
            i.e. the reverse graph is no longer alive.
    """
    closure = set()
    actual_target_nodes = set()
    q = collections.deque()
    for node in roots:
        if node is not None and node not in closure:
            closure.add(node)
            q.append(node)
    while q:
        node = q.popleft()
        reverse_edges = node.metadata.get("reverse_edges", [])
        for holder_ref, idx in reverse_edges:
            ref = holder_ref()
            # BUG FIX: the original raised when the weakref was still alive
            # (`if ref is not None`), which fails on every valid traversal.
            # The error is meant for the dead-reference case.
            if ref is None:
                raise RuntimeError("Reverse graph is no longer alive")
            fn = ref.node
            if fn in closure or fn is None:
                continue
            if fn in target_nodes:
                # Record the hit but do not expand past a target node.
                actual_target_nodes.add(fn)
                continue
            closure.add(fn)
            q.append(fn)
    return closure, actual_target_nodes
# Strong wrapper so weak references can be taken (plain autograd Nodes are
# reached through weakrefs to these Holder objects).
class Holder:
    """Holds a strong reference to an autograd ``node``."""

    def __init__(self, node):
        # The wrapped node; read back via ``holder_ref().node``.
        self.node = node
# TODO: use weak references to avoid reference cycle
def construct_reverse_graph(roots):
    """BFS the autograd graph from ``roots`` and record reversed edges.

    Every reachable node ``fn`` gets a ``"reverse_edges"`` metadata entry:
    a list of ``(weakref-to-Holder(consumer), idx)`` pairs pointing back at
    the nodes that consume its output. Returns the list of Holder objects
    keeping those weak references alive; dropping the returned list tears
    the reverse graph down.
    """
    pending = collections.deque()
    seen_roots = set()
    holders = []
    for root in roots:
        if root is None or root in seen_roots:
            continue
        pending.append(root)
        seen_roots.add(root)
    while pending:
        consumer = pending.popleft()
        for producer, idx in consumer.next_functions:
            if producer is None:
                continue
            # Don't necessarily need to store on the graph
            edges = producer.metadata.get("reverse_edges", [])
            if not edges:
                # First visit of this producer: schedule its expansion.
                pending.append(producer)
            holder = Holder(consumer)
            holder_ref = weakref.ref(holder)
            holders.append(holder)
            edges.append((holder_ref, idx))
            producer.metadata["reverse_edges"] = edges
    return holders
def get_param_groups(inputs, params):
    """Partition ``params`` into groups keyed by shared intermediate nodes.

    For each parameter, compute the reverse closure bounded by the closure
    of ``inputs``; parameters whose closures intersect the input closure at
    a common intermediate node are merged into the same group.

    Args:
        inputs: autograd nodes for the user inputs.
        params: autograd nodes (AccumulateGrad) for the parameters.

    Returns:
        List of unique group dicts with keys ``"params"`` and
        ``"intermediates"`` (both sets of Nodes).
    """
    inputs_closure, _ = reverse_closure(inputs, set())
    param_groups = dict()  # keyed on intermediate nodes
    # The full closure of each param is not needed, only its intersection
    # with the inputs' closure (the unused binding was removed).
    for param in params:
        _, intersected = reverse_closure([param], inputs_closure)
        param_group = {
            "params": {param},
            "intermediates": set(intersected),
        }
        # Merge with any existing group that shares an intermediate.
        # NOTE(review): if `intersected` touches two previously-distinct
        # groups, only the ones encountered here are chained together —
        # confirm this transitive-merge behavior is intended.
        for intermediate in intersected:
            existing = param_groups.get(intermediate, None)
            if existing is not None:
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(
                    param_group["intermediates"]
                )
                param_group = existing
            else:
                param_groups[intermediate] = param_group
    # Sanity check: union of all param_groups params should be equal to all params
    union_params = set()
    seen_ids = set()
    unique_param_groups = []
    for param_group in param_groups.values():
        if id(param_group) not in seen_ids:
            seen_ids.add(id(param_group))
            unique_param_groups.append(param_group)
            union_params = union_params.union(param_group["params"])
    assert union_params == set(params)
    return unique_param_groups
def compute_grads_only_inputs2(roots, inps, weights):
    """First half of a split backward: gradients w.r.t. inputs only.

    Builds the reverse graph, groups ``weights`` by the intermediates they
    share with ``inps``, registers prehooks that stash the gradients
    flowing into those intermediates, then runs ``autograd.grad`` for the
    inputs (retaining the graph for part two).

    Returns:
        ``(dinputs, param_groups)``: the input gradients and the group
        state (including captured intermediate grads) to hand to
        ``compute_grads_only_weights2``.
    """
    root_grad_fns = list(map(_get_grad_fn_or_grad_acc, roots))
    inp_grad_fns = list(map(_get_grad_fn_or_grad_acc, inps))
    weight_grad_fns = list(map(_get_grad_fn_or_grad_acc, weights))

    reverse_graph_refs = construct_reverse_graph(root_grad_fns)
    param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
    # Drop the holders: the reverse graph is only needed for grouping.
    del reverse_graph_refs

    for param_group in param_groups:
        for i, intermediate in enumerate(param_group["intermediates"]):

            def get_hook(param_group, i):
                # Bind loop variables now to avoid late-binding closures.
                def hook(grad_inputs):
                    if param_group.get("grads", None) is None:
                        param_group["grads"] = [None] * len(
                            param_group["intermediates"]
                        )
                    param_group["grads"][i] = grad_inputs

                return hook

            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(get_hook(param_group, i))

    # BUG FIX: the original read the free script variable `out` instead of
    # the `roots` parameter; build one ones-like grad per root instead.
    dinputs = torch.autograd.grad(
        tuple(roots),
        inputs=tuple(inps),
        grad_outputs=tuple(torch.ones_like(r) for r in roots),
        retain_graph=True,
    )
    return dinputs, param_groups
def compute_grads_only_weights2(user_weights, param_groups):
    """Second half of a split backward: gradients w.r.t. the weights.

    Replays backward from each group's intermediate GradientEdges, feeding
    in the intermediate gradients captured during part one, and returns the
    weight gradients ordered like ``user_weights``.
    """
    grads_by_node = dict()
    for group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(
            GradientEdge(node, 0) for node in group["intermediates"]
        )
        weights_edges = tuple(GradientEdge(node, 0) for node in group["params"])
        assert all(len(g) == 1 for g in group["grads"])
        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        print("trying to execute: ", intermediate_edges, weights_edges)
        dweights = torch.autograd.grad(
            intermediate_edges,
            weights_edges,
            grad_outputs=sum(group["grads"], tuple()),
        )
        grads_by_node.update(zip(group["params"], dweights))
    # return grads in the original order weights were provided in
    return tuple(grads_by_node[_get_grad_fn_or_grad_acc(w)] for w in user_weights)
```
</details>
```python
import torch.nn as nn
# Setup
# Two stacked linear layers on a 1-D input. NOTE(review): only torch.nn is
# imported above — `torch` itself is presumably imported earlier; confirm.
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)
a = torch.rand(10, requires_grad=True)
# All trainable parameters from both layers, in layer order.
weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)
# Forward pass whose graph the split backward below will consume.
out = mod2(mod1(a))
class LoggingTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
    """TorchDispatchMode that logs every dispatched op as it runs.

    Executes the op unchanged and prints its fully-qualified name.
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs is not None else {}
        result = func(*args, **kwargs)
        print(f"{func.__module__}.{func.__name__}")
        return result
print(" -- SPLIT -- ")
# Compute gradients in two parts
with LoggingTensorMode():
    print("PART 1")
    # Input grads; `state` carries the captured intermediate gradients.
    dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
    print("PART 2")
    # Weight grads, replayed from the stashed intermediates.
    dweights = compute_grads_only_weights2(weights, state)
# Re-run the forward — presumably because part two freed the original
# graph (no retain_graph there); confirm.
out = mod2(mod1(a))
print(" -- REF -- ")
# Compare with reference
with LoggingTensorMode():
    ref_all_gradients = torch.autograd.grad(out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),))
# Should print True for every (input + weight) gradient pair.
for actual, ref in zip(dinputs + dweights, ref_all_gradients):
    print(torch.allclose(actual, ref))
```
<img width="598" alt="image" src="https://github.com/pytorch/pytorch/assets/13428986/3681b8a7-3ab4-4d1d-a836-abef6913e671">
```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1ee160> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x100d7ae50> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a60> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a90> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1e4bb0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a21b130>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b7c0>, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a1cad60> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1ee0d0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a1e41c0>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b670>, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a21b2b0> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD
This commit is contained in:
parent
c1e7e40f24
commit
2eec02523b
10 changed files with 359 additions and 75 deletions
|
|
@ -2277,6 +2277,7 @@ known_failing_tests = {
|
|||
"test_grad_materialize_grads", # RuntimeError: compiled_autograd does not support create_graph
|
||||
"test_grad_nonleaf", # RuntimeError: compiled_autograd does not support create_graph
|
||||
"test_grad_nonleaf_many_outputs", # RuntimeError: compiled_autograd does not support create_graph
|
||||
"test_gradient_edge_output", # RuntimeError: trying to backward through the graph a second time
|
||||
"test_hessian_vector", # RuntimeError: compiled_autograd does not support create_graph
|
||||
"test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_False", # AttributeError: type object
|
||||
"test_hook_closure_cycle_use_custom_function_True_use_tensor_hook_True", # AttributeError: type object
|
||||
|
|
|
|||
|
|
@ -436,6 +436,7 @@ class TestProfilerTree(TestCase):
|
|||
[memory]""",
|
||||
)
|
||||
|
||||
@unittest.skip("https://github.com/pytorch/pytorch/issues/83606")
|
||||
@unittest.skipIf(
|
||||
TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite."
|
||||
)
|
||||
|
|
@ -478,8 +479,19 @@ class TestProfilerTree(TestCase):
|
|||
<built-in function len>
|
||||
torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple
|
||||
torch/autograd/__init__.py(...): _make_grads
|
||||
typing.py(...): inner
|
||||
typing.py(...): __hash__
|
||||
<built-in function hash>
|
||||
typing.py(...): cast
|
||||
<built-in function isinstance>
|
||||
<built-in function isinstance>
|
||||
<built-in function isinstance>
|
||||
<built-in function isinstance>
|
||||
<built-in function isinstance>
|
||||
<built-in function isinstance>
|
||||
<built-in method numel of Tensor object at 0xXXXXXXXXXXXX>
|
||||
<built-in function isinstance>
|
||||
<built-in function isinstance>
|
||||
<built-in method ones_like of type object at 0xXXXXXXXXXXXX>
|
||||
aten::ones_like
|
||||
aten::empty_like
|
||||
|
|
@ -910,6 +922,9 @@ class TestProfilerTree(TestCase):
|
|||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
|
||||
@ProfilerTree.test
|
||||
def test_profiler_experimental_tree_cuda_detailed(self):
|
||||
# Do lazy imports ahead of time to avoid it showing up in the tree
|
||||
import torch.nested._internal.nested_tensor
|
||||
|
||||
model = torch.nn.modules.Linear(1, 1, device="cuda")
|
||||
model.train()
|
||||
opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
|
||||
|
|
@ -967,8 +982,19 @@ class TestProfilerTree(TestCase):
|
|||
<built-in function len>
|
||||
torch/autograd/__init__.py(...): _tensor_or_tensors_to_tuple
|
||||
torch/autograd/__init__.py(...): _make_grads
|
||||
typing.py(...): inner
|
||||
typing.py(...): __hash__
|
||||
<built-in function hash>
|
||||
typing.py(...): cast
|
||||
<built-in function isinstance>
|
||||
<built-in function isinstance>
|
||||
<built-in function isinstance>
|
||||
<built-in function isinstance>
|
||||
<built-in function isinstance>
|
||||
<built-in function isinstance>
|
||||
<built-in method numel of Tensor object at 0xXXXXXXXXXXXX>
|
||||
<built-in function isinstance>
|
||||
<built-in function isinstance>
|
||||
<built-in method ones_like of type object at 0xXXXXXXXXXXXX>
|
||||
aten::ones_like
|
||||
aten::empty_like
|
||||
|
|
|
|||
|
|
@ -985,6 +985,113 @@ class TestAutograd(TestCase):
|
|||
torch.autograd.backward(out.sum(), inputs=(x, edge_y))
|
||||
torch.autograd.backward(out.sum(), inputs=(edge_x, edge_y))
|
||||
|
||||
def test_grad_fn_input_metadata(self):
|
||||
x = torch.rand(2, requires_grad=True, dtype=torch.float32)
|
||||
y = torch.rand(2, requires_grad=True, dtype=torch.float32)
|
||||
z = x * y
|
||||
z_metadata = z.grad_fn._input_metadata[0]
|
||||
self.assertEqual(z_metadata.shape, (2,))
|
||||
self.assertEqual(z_metadata.dtype, torch.float32)
|
||||
|
||||
# Multiple outputs
|
||||
b = torch.rand(3, 3, requires_grad=True)
|
||||
var, _ = torch.var_mean(b, dim=0)
|
||||
|
||||
metadata_0 = var.grad_fn._input_metadata[0]
|
||||
metadata_1 = var.grad_fn._input_metadata[1]
|
||||
self.assertEqual(metadata_0.shape, (3,))
|
||||
self.assertEqual(metadata_1.shape, (3,))
|
||||
|
||||
# Preserves symints
|
||||
nt = torch.nested.nested_tensor(
|
||||
[
|
||||
torch.randn(
|
||||
3,
|
||||
2,
|
||||
),
|
||||
torch.randn(
|
||||
2,
|
||||
2,
|
||||
),
|
||||
],
|
||||
layout=torch.jagged,
|
||||
requires_grad=True,
|
||||
)
|
||||
nt_metadata = nt.clone().grad_fn._input_metadata[0]
|
||||
|
||||
self.assertIsInstance(nt_metadata.shape[1], torch.SymInt)
|
||||
self.assertEqual(nt_metadata.shape, nt.shape)
|
||||
self.assertTrue(nt_metadata.is_nested_tensor)
|
||||
self.assertFalse(nt_metadata.is_cpp_nested_tensor)
|
||||
self.assertEqual(nt_metadata.dtype, nt.dtype)
|
||||
|
||||
class Test(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x):
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
|
||||
x = torch.randn(3, 3, requires_grad=True)
|
||||
x = Test.apply(x)
|
||||
metadata = x.grad_fn._input_metadata[0]
|
||||
self.assertEqual(metadata.shape, (3, 3))
|
||||
|
||||
def test_gradient_edge_output(self):
|
||||
x = torch.tensor([1.0, 2.0], requires_grad=True)
|
||||
|
||||
def fn(x, reduce=True):
|
||||
tmp = x.sin().cos()
|
||||
if reduce:
|
||||
tmp = tmp.sum()
|
||||
out = tmp.exp().clone().sin().sum()
|
||||
tmp_edge = torch.autograd.graph.get_gradient_edge(tmp)
|
||||
return out, tmp_edge
|
||||
|
||||
# Compute fn backward in two steps
|
||||
out, tmp_edge = fn(x)
|
||||
(tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
|
||||
|
||||
(x_grad,) = torch.autograd.grad(tmp_edge, (x,), grad_outputs=(tmp_grad,))
|
||||
|
||||
# Compare with as if we did it in one go.
|
||||
out, _ = fn(x)
|
||||
(x_grad_ref,) = torch.autograd.grad(out, (x,))
|
||||
self.assertEqual(x_grad, x_grad_ref)
|
||||
|
||||
# Incorrect case: grad_outputs not passed/implicitly None and output is
|
||||
# not a scalar
|
||||
out, tmp_edge = fn(x, reduce=False)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"grad can be implicitly created only for scalar output",
|
||||
):
|
||||
torch.autograd.grad(tmp_edge, (x,))
|
||||
|
||||
# grad_outputs is None, and output is a scalar is fine
|
||||
out, tmp_edge = fn(x, reduce=True)
|
||||
torch.autograd.grad(tmp_edge, (x,))
|
||||
|
||||
# Incorrect case: grad_outputs wrong size
|
||||
out, tmp_edge = fn(x)
|
||||
(tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
|
||||
with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
|
||||
torch.autograd.grad(
|
||||
tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0])
|
||||
)
|
||||
|
||||
# Incorrect case: wrong dtype
|
||||
out, tmp_edge = fn(x)
|
||||
(tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
|
||||
with self.assertRaisesRegex(RuntimeError, "required to have the same dtype"):
|
||||
torch.autograd.grad(
|
||||
tmp_edge,
|
||||
(x,),
|
||||
grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64),
|
||||
)
|
||||
|
||||
def test_grad_nonleaf(self):
|
||||
x_init = torch.randn(2, 2, requires_grad=True)
|
||||
x = x_init
|
||||
|
|
|
|||
|
|
@ -56,12 +56,21 @@ _ShapeorNestedShape = Union[_size, Sequence[_size], torch.Tensor]
|
|||
|
||||
|
||||
def _calculate_shape(
|
||||
output: torch.Tensor, grad: torch.Tensor, is_grads_batched: bool
|
||||
output: Union[torch.Tensor, graph.GradientEdge],
|
||||
grad: torch.Tensor,
|
||||
is_grads_batched: bool,
|
||||
) -> Tuple[_ShapeorNestedShape, _ShapeorNestedShape]:
|
||||
# is_same_size ensures that both tensors are either nested or non nested
|
||||
# circular import
|
||||
from torch.nested._internal.nested_tensor import NestedTensor
|
||||
|
||||
if isinstance(output, graph.GradientEdge):
|
||||
# We have already checked that we are not a C++ NestedTensor
|
||||
if is_grads_batched:
|
||||
raise RuntimeError("Batched grads are not supported with GradientEdge")
|
||||
out_metadata = output.node._input_metadata[output.output_nr]
|
||||
return torch.Size(out_metadata.shape), grad.shape
|
||||
|
||||
if output.is_nested and not isinstance(output, NestedTensor):
|
||||
if is_grads_batched:
|
||||
raise RuntimeError("Batched grads are not supported with Nested Tensor.")
|
||||
|
|
@ -76,27 +85,58 @@ def _calculate_shape(
|
|||
|
||||
|
||||
def _make_grads(
|
||||
outputs: Sequence[torch.Tensor],
|
||||
outputs: Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]],
|
||||
grads: Sequence[_OptionalTensor],
|
||||
is_grads_batched: bool,
|
||||
) -> Tuple[_OptionalTensor, ...]:
|
||||
new_grads: List[_OptionalTensor] = []
|
||||
for out, grad in zip(outputs, grads):
|
||||
out = cast(Union[torch.Tensor, graph.GradientEdge], out)
|
||||
out_size = None
|
||||
out_device = None
|
||||
|
||||
if isinstance(out, graph.GradientEdge):
|
||||
out_metadata = out.node._input_metadata[out.output_nr]
|
||||
out_size = torch.Size(out_metadata.shape)
|
||||
out_dtype = out_metadata.dtype
|
||||
out_device = out_metadata.device
|
||||
out_is_nested = out_metadata.is_nested_tensor
|
||||
if out_metadata.is_cpp_nested_tensor:
|
||||
raise RuntimeError(
|
||||
"C++ NestedTensor are not supported with GradientEdge"
|
||||
)
|
||||
out_is_cpp_nested = False
|
||||
else:
|
||||
# circular import
|
||||
from torch.nested._internal.nested_tensor import NestedTensor
|
||||
|
||||
assert isinstance(out, torch.Tensor)
|
||||
out_dtype = out.dtype
|
||||
out_is_nested = out.is_nested
|
||||
out_is_cpp_nested = out_is_nested and not isinstance(out, NestedTensor)
|
||||
if not out_is_cpp_nested:
|
||||
out_size = out.shape
|
||||
|
||||
if isinstance(grad, torch.Tensor):
|
||||
from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq
|
||||
|
||||
first_grad = grad if not is_grads_batched else grad[0]
|
||||
|
||||
# TODO: We can remove this conditional once we uniformly use
|
||||
# singleton int to represent jagged dimension, so that size() call
|
||||
# on nested tensor works
|
||||
if out.is_nested or first_grad.is_nested:
|
||||
# on nested tensor works.
|
||||
if out_is_cpp_nested:
|
||||
assert isinstance(out, torch.Tensor)
|
||||
shape_matches = torch.is_same_size(out, first_grad)
|
||||
else:
|
||||
# We need to do a regular size check, without going through
|
||||
# the operator, to be able to handle unbacked symints
|
||||
# (expect_true ensures we can deal with unbacked)
|
||||
shape_matches = expect_true(sym_eq(out.size(), first_grad.size()))
|
||||
assert out_size is not None
|
||||
shape_matches = expect_true(sym_eq(out_size, first_grad.size()))
|
||||
|
||||
if not shape_matches:
|
||||
out = cast(Union[torch.Tensor, graph.GradientEdge], out)
|
||||
out_shape, grad_shape = _calculate_shape(
|
||||
out, first_grad, is_grads_batched
|
||||
)
|
||||
|
|
@ -130,7 +170,7 @@ def _make_grads(
|
|||
+ str(out_shape)
|
||||
+ "."
|
||||
)
|
||||
if out.dtype.is_complex != grad.dtype.is_complex:
|
||||
if out_dtype.is_complex != grad.dtype.is_complex:
|
||||
raise RuntimeError(
|
||||
"For complex Tensors, both grad_output and output"
|
||||
" are required to have the same dtype."
|
||||
|
|
@ -141,25 +181,43 @@ def _make_grads(
|
|||
+ " and output["
|
||||
+ str(outputs.index(out))
|
||||
+ "] has a dtype of "
|
||||
+ str(out.dtype)
|
||||
+ str(out_dtype)
|
||||
+ "."
|
||||
)
|
||||
new_grads.append(grad)
|
||||
elif grad is None:
|
||||
if out.requires_grad:
|
||||
if out.numel() != 1:
|
||||
if isinstance(out, graph.GradientEdge) or out.requires_grad: # type: ignore[attr-defined]
|
||||
if isinstance(out, graph.GradientEdge):
|
||||
assert out_size is not None
|
||||
out_numel_is_1 = all(o == 1 for o in out_size)
|
||||
else:
|
||||
assert isinstance(out, torch.Tensor)
|
||||
out_numel_is_1 = out.numel() == 1
|
||||
if not out_numel_is_1:
|
||||
raise RuntimeError(
|
||||
"grad can be implicitly created only for scalar outputs"
|
||||
)
|
||||
if not out.dtype.is_floating_point:
|
||||
if not out_dtype.is_floating_point:
|
||||
msg = (
|
||||
"grad can be implicitly created only for real scalar outputs"
|
||||
f" but got {out.dtype}"
|
||||
f" but got {out_dtype}"
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
new_grads.append(
|
||||
torch.ones_like(out, memory_format=torch.preserve_format)
|
||||
)
|
||||
if isinstance(out, graph.GradientEdge):
|
||||
assert out_size is not None
|
||||
assert out_device is not None
|
||||
new_grads.append(
|
||||
torch.ones(
|
||||
out_size,
|
||||
dtype=out_dtype,
|
||||
device=out_device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert isinstance(out, torch.Tensor)
|
||||
new_grads.append(
|
||||
torch.ones_like(out, memory_format=torch.preserve_format)
|
||||
)
|
||||
else:
|
||||
new_grads.append(None)
|
||||
else:
|
||||
|
|
@ -297,7 +355,7 @@ def backward(
|
|||
|
||||
|
||||
def grad(
|
||||
outputs: _TensorOrTensors,
|
||||
outputs: _TensorOrTensorsOrGradEdge,
|
||||
inputs: _TensorOrTensorsOrGradEdge,
|
||||
grad_outputs: Optional[_TensorOrTensors] = None,
|
||||
retain_graph: Optional[bool] = None,
|
||||
|
|
@ -327,7 +385,7 @@ def grad(
|
|||
``torch.autograd.backward``.
|
||||
|
||||
Args:
|
||||
outputs (sequence of Tensor): outputs of the differentiated function.
|
||||
outputs (sequence of Tensor or GradientEdge): outputs of the differentiated function.
|
||||
inputs (sequence of Tensor or GradientEdge): Inputs w.r.t. which the gradient will be
|
||||
returned (and not accumulated into ``.grad``).
|
||||
grad_outputs (sequence of Tensor): The "vector" in the vector-Jacobian product.
|
||||
|
|
@ -369,21 +427,24 @@ def grad(
|
|||
)
|
||||
if allow_unused is None:
|
||||
allow_unused = materialize_grads
|
||||
t_outputs = cast(
|
||||
Tuple[torch.Tensor, ...],
|
||||
(outputs,) if is_tensor_like(outputs) else tuple(outputs),
|
||||
)
|
||||
if is_tensor_like(outputs) or isinstance(outputs, graph.GradientEdge):
|
||||
outputs = cast(
|
||||
Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
|
||||
)
|
||||
else:
|
||||
outputs = tuple(outputs)
|
||||
if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
|
||||
inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
|
||||
else:
|
||||
inputs = tuple(inputs)
|
||||
t_outputs = tuple(i for i in outputs if is_tensor_like(i))
|
||||
t_inputs = tuple(i for i in inputs if is_tensor_like(i))
|
||||
overridable_args = t_outputs + t_inputs
|
||||
if has_torch_function(overridable_args):
|
||||
return handle_torch_function(
|
||||
grad,
|
||||
overridable_args,
|
||||
t_outputs,
|
||||
outputs,
|
||||
inputs,
|
||||
grad_outputs=grad_outputs,
|
||||
retain_graph=retain_graph,
|
||||
|
|
@ -403,9 +464,9 @@ def grad(
|
|||
stacklevel=2,
|
||||
)
|
||||
|
||||
grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(t_outputs))
|
||||
grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
|
||||
grad_outputs_ = _make_grads(
|
||||
t_outputs, grad_outputs_, is_grads_batched=is_grads_batched
|
||||
outputs, grad_outputs_, is_grads_batched=is_grads_batched
|
||||
)
|
||||
|
||||
if retain_graph is None:
|
||||
|
|
@ -418,7 +479,7 @@ def grad(
|
|||
|
||||
def vjp(gO):
|
||||
return _engine_run_backward(
|
||||
t_outputs,
|
||||
outputs,
|
||||
gO,
|
||||
retain_graph,
|
||||
create_graph,
|
||||
|
|
@ -432,7 +493,7 @@ def grad(
|
|||
)
|
||||
else:
|
||||
result = _engine_run_backward(
|
||||
t_outputs,
|
||||
outputs,
|
||||
grad_outputs_,
|
||||
retain_graph,
|
||||
create_graph,
|
||||
|
|
|
|||
|
|
@ -79,6 +79,11 @@ class Node(abc.ABC):
|
|||
r"""Return the metadata."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _input_metadata(self) -> List[Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def _register_hook_dict(self, tensor: torch.Tensor) -> None:
|
||||
raise NotImplementedError
|
||||
|
|
@ -170,7 +175,9 @@ class Node(abc.ABC):
|
|||
return NotImplemented
|
||||
|
||||
|
||||
def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node:
|
||||
def _get_grad_fn_or_grad_acc(t: Union[torch.Tensor, "GradientEdge"]) -> Node:
|
||||
if isinstance(t, GradientEdge):
|
||||
return t.node
|
||||
if t.requires_grad and t.grad_fn is None:
|
||||
node = t.view_as(t).grad_fn.next_functions[0][0] # type: ignore[union-attr]
|
||||
else:
|
||||
|
|
@ -738,7 +745,7 @@ def allow_mutation_on_saved_tensors() -> (
|
|||
|
||||
|
||||
def _register_logging_hooks_on_whole_graph(
|
||||
t_outputs: Sequence[torch.Tensor],
|
||||
t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
|
||||
) -> Callable[[], None]:
|
||||
grad_fns = list(map(_get_grad_fn_or_grad_acc, t_outputs))
|
||||
|
||||
|
|
@ -788,7 +795,7 @@ def _register_logging_hooks_on_whole_graph(
|
|||
|
||||
|
||||
def _engine_run_backward(
|
||||
t_outputs: Sequence[torch.Tensor],
|
||||
t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Tuple[torch.Tensor, ...]:
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
|
||||
#include <torch/csrc/autograd/function.h>
|
||||
#include <torch/csrc/autograd/grad_mode.h>
|
||||
#include <torch/csrc/autograd/input_metadata.h>
|
||||
#include <torch/csrc/autograd/profiler.h>
|
||||
#include <torch/csrc/autograd/profiler_python.h>
|
||||
#include <torch/csrc/autograd/python_function.h>
|
||||
|
|
@ -184,6 +185,23 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
|||
.value("NO_GRAD_MODE", CreationMeta::NO_GRAD_MODE)
|
||||
.value("INFERENCE_MODE", CreationMeta::INFERENCE_MODE);
|
||||
|
||||
py::class_<torch::autograd::InputMetadata>(m, "_InputMetadata")
|
||||
.def_property_readonly(
|
||||
"dtype",
|
||||
[](const torch::autograd::InputMetadata& m) {
|
||||
PyObject* raw_obj =
|
||||
(PyObject*)torch::getTHPDtype(m.dtype().toScalarType());
|
||||
return py::reinterpret_borrow<py::object>(raw_obj);
|
||||
})
|
||||
.def_property_readonly("device", &torch::autograd::InputMetadata::device)
|
||||
.def_property_readonly(
|
||||
"shape", &torch::autograd::InputMetadata::shape_as_dim_vector)
|
||||
.def_property_readonly(
|
||||
"is_nested_tensor", &torch::autograd::InputMetadata::is_nested_tensor)
|
||||
.def_property_readonly(
|
||||
"is_cpp_nested_tensor",
|
||||
&torch::autograd::InputMetadata::is_cpp_nested_tensor);
|
||||
|
||||
py::class_<KinetoEvent>(m, "_KinetoEvent")
|
||||
// name of the event
|
||||
.def("name", [](const KinetoEvent& e) { return e.name(); })
|
||||
|
|
|
|||
|
|
@ -210,6 +210,27 @@ PyObject* THPCppFunction_set_sequence_nr(
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPCppFunction_input_metadata(PyObject* self, void* closure) {
|
||||
HANDLE_TH_ERRORS;
|
||||
auto& fn = *((THPCppFunction*)self)->cdata;
|
||||
const auto num_inputs =
|
||||
fn.num_inputs(); // Assuming there's a method to get the number of inputs
|
||||
THPObjectPtr list(PyTuple_New(num_inputs));
|
||||
if (!list) {
|
||||
return nullptr;
|
||||
}
|
||||
for (size_t i = 0; i < num_inputs; ++i) {
|
||||
const auto& metadata = fn.input_metadata(i);
|
||||
THPObjectPtr item(py::cast(metadata).release().ptr());
|
||||
if (!item) {
|
||||
return nullptr;
|
||||
}
|
||||
PyTuple_SET_ITEM(list.get(), i, item.release());
|
||||
}
|
||||
return list.release();
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
|
||||
static struct PyMethodDef default_methods[] = {
|
||||
THP_FUNCTION_DEFAULT_METHODS,
|
||||
|
|
|
|||
|
|
@ -51,19 +51,21 @@ PyObject* CppFunction_pynew(
|
|||
(char*)"_set_sequence_nr", THPCppFunction_set_sequence_nr, METH_O, nullptr \
|
||||
}
|
||||
|
||||
#define THP_FUNCTION_DEFAULT_PROPERTIES \
|
||||
{(char*)"next_functions", \
|
||||
THPCppFunction_next_functions, \
|
||||
nullptr, \
|
||||
nullptr, \
|
||||
nullptr}, \
|
||||
{(char*)"requires_grad", \
|
||||
THPCppFunction_requires_grad, \
|
||||
nullptr, \
|
||||
nullptr, \
|
||||
nullptr}, \
|
||||
{ \
|
||||
(char*)"metadata", THPCppFunction_metadata, nullptr, nullptr, nullptr \
|
||||
#define THP_FUNCTION_DEFAULT_PROPERTIES \
|
||||
{(char*)"next_functions", \
|
||||
THPCppFunction_next_functions, \
|
||||
nullptr, \
|
||||
nullptr, \
|
||||
nullptr}, \
|
||||
{(char*)"requires_grad", \
|
||||
THPCppFunction_requires_grad, \
|
||||
nullptr, \
|
||||
nullptr, \
|
||||
nullptr}, \
|
||||
{(char*)"metadata", THPCppFunction_metadata, nullptr, nullptr, nullptr}, \
|
||||
{ \
|
||||
(char*)"_input_metadata", THPCppFunction_input_metadata, nullptr, nullptr, \
|
||||
nullptr \
|
||||
}
|
||||
|
||||
PyObject* THPCppFunction_next_functions(PyObject* self, void* _unused);
|
||||
|
|
@ -75,6 +77,7 @@ PyObject* THPCppFunction_register_prehook(PyObject* self, PyObject* hook);
|
|||
|
||||
PyObject* THPCppFunction_name(PyObject* self, PyObject* noargs);
|
||||
PyObject* THPCppFunction_sequence_nr(PyObject* self, PyObject* noargs);
|
||||
PyObject* THPCppFunction_input_metadata(PyObject* self, void* _unused);
|
||||
|
||||
PyTypeObject* _initFunctionPyTypeObject(
|
||||
PyTypeObject& type,
|
||||
|
|
|
|||
|
|
@ -166,6 +166,24 @@ c10::intrusive_ptr<at::ivalue::Future> PythonEngine::execute_with_graph_task(
|
|||
|
||||
PyObject* THPEngineClass = nullptr;
|
||||
|
||||
inline static Edge parseGradientEdge(PyObject* obj, int64_t index) {
|
||||
PyObject* grad_fn = PyTuple_GetItem(obj, 0);
|
||||
auto output_nr = THPUtils_unpackLong(PyTuple_GetItem(obj, 1));
|
||||
std::shared_ptr<torch::autograd::Node> grad_fn_sp;
|
||||
if (THPFunction_Check(grad_fn)) {
|
||||
grad_fn_sp = ((THPFunction*)grad_fn)->cdata.lock();
|
||||
} else if (THPCppFunction_Check(grad_fn)) {
|
||||
grad_fn_sp = ((THPCppFunction*)grad_fn)->cdata;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"GradientEdge's first object must be an autograd.graph.Node "
|
||||
"but got ",
|
||||
THPUtils_typename(grad_fn));
|
||||
}
|
||||
return Edge(grad_fn_sp, output_nr);
|
||||
}
|
||||
|
||||
// Implementation of torch._C._EngineBase.run_backward
|
||||
PyObject* THPEngine_run_backward(
|
||||
PyObject* self,
|
||||
|
|
@ -239,22 +257,29 @@ PyObject* THPEngine_run_backward(
|
|||
grads.reserve(num_tensors);
|
||||
for (const auto i : c10::irange(num_tensors)) {
|
||||
PyObject* _tensor = PyTuple_GET_ITEM(tensors, i);
|
||||
TORCH_CHECK(
|
||||
THPVariable_Check(_tensor),
|
||||
"element ",
|
||||
i,
|
||||
" of tensors tuple is not a Tensor");
|
||||
const auto& variable = THPVariable_Unpack(_tensor);
|
||||
TORCH_CHECK(
|
||||
!isBatchedTensor(variable),
|
||||
"torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
|
||||
"torch.vmap. We do not support the case where any outputs are ",
|
||||
"vmapped tensors (output ",
|
||||
i,
|
||||
" is being vmapped over). Please "
|
||||
"call autograd.grad() outside torch.vmap or file a bug report "
|
||||
"with your use case.")
|
||||
auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
|
||||
Edge gradient_edge; // Temporary variable to hold the gradient edge
|
||||
c10::optional<at::Tensor> mb_output;
|
||||
if (THPVariable_Check(_tensor)) {
|
||||
mb_output = THPVariable_Unpack(_tensor);
|
||||
TORCH_CHECK(
|
||||
!isBatchedTensor(mb_output.value()),
|
||||
"torch.autograd.grad(outputs, inputs, grad_outputs) called inside ",
|
||||
"torch.vmap. We do not support the case where any outputs are ",
|
||||
"vmapped tensors (output ",
|
||||
i,
|
||||
" is being vmapped over). Please "
|
||||
"call autograd.grad() outside torch.vmap or file a bug report "
|
||||
"with your use case.");
|
||||
gradient_edge = torch::autograd::impl::gradient_edge(mb_output.value());
|
||||
} else if (PyObject_IsInstance(_tensor, THPGradientEdgeClass)) {
|
||||
gradient_edge = parseGradientEdge(_tensor, i);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"element ",
|
||||
i,
|
||||
" of tensors tuple is neither a Tensor nor a GradientEdge");
|
||||
}
|
||||
TORCH_CHECK(
|
||||
gradient_edge.function,
|
||||
"element ",
|
||||
|
|
@ -281,7 +306,13 @@ PyObject* THPEngine_run_backward(
|
|||
i,
|
||||
" of gradients tuple is not a Tensor or None");
|
||||
TORCH_CHECK(
|
||||
!variable.requires_grad(),
|
||||
mb_output.has_value(),
|
||||
"element ",
|
||||
i,
|
||||
" of gradients tuple is None, but the corresponding output is a GradientEdge."
|
||||
"This is not supported.");
|
||||
TORCH_CHECK(
|
||||
!mb_output.value().requires_grad(),
|
||||
"element ",
|
||||
i,
|
||||
" of gradients tuple is None, but the corresponding Tensor requires grad");
|
||||
|
|
@ -330,23 +361,7 @@ PyObject* THPEngine_run_backward(
|
|||
output_edges.emplace_back(grad_fn, output_nr);
|
||||
}
|
||||
} else if (PyObject_IsInstance(input, THPGradientEdgeClass)) {
|
||||
auto node = PyTuple_GetItem(input, 0);
|
||||
bool isTHPFunction = THPFunction_Check(node);
|
||||
bool isTHPCppFunction = THPCppFunction_Check(node);
|
||||
TORCH_CHECK(
|
||||
isTHPFunction || isTHPCppFunction,
|
||||
"GradientEdge first object must be an autograd.graph.Node "
|
||||
"but got ",
|
||||
THPUtils_typename(node));
|
||||
std::shared_ptr<torch::autograd::Node> node_sp;
|
||||
if (isTHPFunction) {
|
||||
node_sp = ((THPFunction*)node)->cdata.lock();
|
||||
} else {
|
||||
node_sp = ((torch::autograd::THPCppFunction*)node)->cdata;
|
||||
}
|
||||
|
||||
auto output_nr = THPUtils_unpackUInt32(PyTuple_GetItem(input, 1));
|
||||
output_edges.emplace_back(node_sp, output_nr);
|
||||
output_edges.emplace_back(parseGradientEdge(input, i));
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
|
|
|
|||
|
|
@ -1181,6 +1181,26 @@ PyObject* THPFunction_set_sequence_nr(PyObject* self, PyObject* sequence_nr) {
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPFunction_input_metadata(PyObject* self, void* unused) {
|
||||
HANDLE_TH_ERRORS;
|
||||
auto cdata = ((THPFunction*)self)->cdata.lock();
|
||||
const auto num_inputs = cdata->num_inputs();
|
||||
THPObjectPtr list(PyTuple_New(num_inputs));
|
||||
if (!list) {
|
||||
return nullptr;
|
||||
}
|
||||
for (size_t i = 0; i < num_inputs; ++i) {
|
||||
const auto& metadata = cdata->input_metadata(i);
|
||||
THPObjectPtr item(py::cast(metadata).release().ptr());
|
||||
if (!item) {
|
||||
return nullptr;
|
||||
}
|
||||
PyTuple_SET_ITEM(list.get(), i, item.release());
|
||||
}
|
||||
return list.release();
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPFunction_maybe_clear_saved_tensors(
|
||||
PyObject* self,
|
||||
PyObject* noargs) {
|
||||
|
|
@ -1723,6 +1743,11 @@ static struct PyGetSetDef THPFunction_properties[] = {
|
|||
nullptr},
|
||||
{"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},
|
||||
{"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr},
|
||||
{"_input_metadata",
|
||||
(getter)THPFunction_input_metadata,
|
||||
nullptr,
|
||||
nullptr,
|
||||
nullptr},
|
||||
{"materialize_grads",
|
||||
nullptr,
|
||||
(setter)THPFunction_set_materialize_grads,
|
||||
|
|
|
|||
Loading…
Reference in a new issue