This is useful for splitting grad into two parts (input grads first, then weight grads) while preserving the intermediates needed by the second part:
<details>
<summary>
Click to see code
</summary>
```python
import collections
import weakref
from typing import Deque

import torch
from torch.autograd.graph import GradientEdge
def _get_grad_fn_or_grad_acc(t):
    # Leaf tensors have no grad_fn; materialize their AccumulateGrad node
    # through a no-op view instead.
    if t.requires_grad and t.grad_fn is None:
        return t.view_as(t).grad_fn.next_functions[0][0]
else:
return t.grad_fn
def reverse_closure(roots, target_nodes):
    # BFS from `roots` along the recorded reverse edges, collecting every
    # reachable node and stopping at (and recording) target nodes.
closure = set()
actual_target_nodes = set()
q: Deque = collections.deque()
for node in roots:
if node is not None and node not in closure:
closure.add(node)
q.append(node)
while q:
node = q.popleft()
reverse_edges = node.metadata.get("reverse_edges", [])
for holder_ref, idx in reverse_edges:
ref = holder_ref()
                if ref is None:
                    raise RuntimeError("Reverse graph is no longer alive")
fn = ref.node
if fn in closure or fn is None:
continue
if fn in target_nodes:
actual_target_nodes.add(fn)
continue
closure.add(fn)
q.append(fn)
return closure, actual_target_nodes
# Wrapper whose lifetime controls the reverse graph: reverse edges hold only
# weak references to Holders, so dropping the Holder list invalidates them all.
class Holder:
    def __init__(self, node):
        self.node = node
# TODO: use weak references to avoid reference cycle
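# Build a reverse view of the autograd graph reachable from `roots`: each
# node's metadata gains back-pointers (weakrefs to Holders) to the nodes that
# feed gradients into it during backward. The returned Holder list keeps
# those edges alive; dropping it invalidates the reverse graph.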
def construct_reverse_graph(roots):
q: Deque = collections.deque()
root_seen = set()
reverse_graph_refs = []
for node in roots:
if node is not None and node not in root_seen:
q.append(node)
root_seen.add(node)
while q:
node = q.popleft()
for fn, idx in node.next_functions:
if fn is not None:
                # Note: the reverse edges need not live on node metadata; any
                # side table keyed by node would work just as well.
reverse_edges = fn.metadata.get("reverse_edges", [])
if len(reverse_edges) == 0:
q.append(fn)
holder = Holder(node)
holder_ref = weakref.ref(holder)
reverse_graph_refs.append(holder)
reverse_edges.append((holder_ref, idx))
fn.metadata["reverse_edges"] = reverse_edges
return reverse_graph_refs
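# Partition the weights into groups keyed by the intermediate nodes where
# their subgraphs meet the inputs' closure; groups that share an intermediate
# are merged so each group's grads can be computed independently.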
def get_param_groups(inputs, params):
inputs_closure, _ = reverse_closure(inputs, set())
param_groups = dict() # keyed on intermediates
    for param in params:
closure, intersected = reverse_closure([param], inputs_closure)
param_group = {
"params": set([param]),
"intermediates": set(intersected),
}
        for intermediate_node in intersected:
            existing = param_groups.get(intermediate_node, None)
            if existing is not None:
                existing["params"] = existing["params"].union(param_group["params"])
                existing["intermediates"] = existing["intermediates"].union(
                    param_group["intermediates"]
                )
                param_group = existing
            else:
                param_groups[intermediate_node] = param_group
# Sanity check: union of all param_groups params should be equal to all params
union_params = set()
seen_ids = set()
unique_param_groups = []
for param_group in param_groups.values():
if id(param_group) not in seen_ids:
seen_ids.add(id(param_group))
unique_param_groups.append(param_group)
union_params = union_params.union(param_group["params"])
assert union_params == set(params)
return unique_param_groups
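# Phase 1: run backward only as far as the inputs, installing prehooks that
# stash the gradients flowing into each group's intermediate nodes.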
def compute_grads_only_inputs2(roots, inps, weights):
root_grad_fns = list(map(_get_grad_fn_or_grad_acc, roots))
inp_grad_fns = list(map(_get_grad_fn_or_grad_acc, inps))
weight_grad_fns = list(map(_get_grad_fn_or_grad_acc, weights))
reverse_graph_refs = construct_reverse_graph(root_grad_fns)
param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
del reverse_graph_refs
for param_group in param_groups:
for i, intermediate in enumerate(param_group["intermediates"]):
def get_hook(param_group, i):
def hook(grad_inputs):
if param_group.get("grads", None) is None:
param_group["grads"] = [None] * len(param_group["intermediates"])
param_group["grads"][i] = grad_inputs
return hook
            # These intermediates are the "split" nodes between the two phases;
            # stash the grads flowing into them so phase 2 can resume from there.
intermediate.register_prehook(get_hook(param_group, i))
    dinputs = torch.autograd.grad(
        roots,
        inputs=tuple(inps),
        grad_outputs=tuple(torch.ones_like(r) for r in roots),
        retain_graph=True,
    )
return dinputs, param_groups
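# Phase 2: resume backward from the stashed intermediate grads down to the
# weights, addressing both ends of each subgraph via GradientEdge.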
def compute_grads_only_weights2(user_weights, param_groups):
all_dweights = dict()
for param_group in param_groups:
# TODO: Handle case where intermediate can have multiple outputs
intermediate_edges = tuple(GradientEdge(i, 0) for i in param_group["intermediates"])
weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])
assert all(len(g) == 1 for g in param_group["grads"])
# [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # No retain_graph needed: the param groups partition the weight
        # subgraphs, so no node should be executed twice across iterations.
print("trying to execute: ", intermediate_edges, weights_edges)
dweights = torch.autograd.grad(intermediate_edges, weights_edges, grad_outputs=sum(param_group["grads"], tuple()))
for w, dw in zip(param_group["params"], dweights):
all_dweights[w] = dw
# return grads in the original order weights were provided in
out = []
for w in user_weights:
grad_acc = _get_grad_fn_or_grad_acc(w)
out.append(all_dweights[grad_acc])
return tuple(out)
```
</details>
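Standalone, the new `GradientEdge`-as-output capability looks roughly like the sketch below (tensor names are illustrative):

```python
import torch
from torch.autograd.graph import GradientEdge

x = torch.rand(3, requires_grad=True)
y = x * 2  # intermediate whose grad_fn we treat as the "output"

# Differentiate starting from an edge in the graph rather than from a tensor.
# grad_outputs must be given explicitly: an edge carries no tensor to call
# ones_like on.
edge = GradientEdge(y.grad_fn, 0)
(dx,) = torch.autograd.grad((edge,), inputs=(x,), grad_outputs=(torch.ones(3),))
print(dx)  # tensor([2., 2., 2.])
```

The helpers above combine this with node prehooks into the two-phase flow demonstrated below: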
```python
import torch
import torch.nn as nn
from torch.utils._python_dispatch import TorchDispatchMode
# Setup
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)
a = torch.rand(10, requires_grad=True)
weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)
out = mod2(mod1(a))
class LoggingTensorMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
rs = func(*args, **kwargs)
print(f"{func.__module__}.{func.__name__}")
return rs
print(" -- SPLIT -- ")
# Compute gradients in two parts
with LoggingTensorMode():
print("PART 1")
dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
print("PART 2")
dweights = compute_grads_only_weights2(weights, state)
out = mod2(mod1(a))
print(" -- REF -- ")
# Compare with reference
with LoggingTensorMode():
ref_all_gradients = torch.autograd.grad(out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),))
for actual, ref in zip(dinputs + dweights, ref_all_gradients):
print(torch.allclose(actual, ref))
```
<img width="598" alt="image" src="https://github.com/pytorch/pytorch/assets/13428986/3681b8a7-3ab4-4d1d-a836-abef6913e671">
```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1ee160> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x100d7ae50> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a60> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a90> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1e4bb0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a21b130>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b7c0>, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a1cad60> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1ee0d0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a1e41c0>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b670>, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a21b2b0> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
```
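For context, the phase-1 gradient capture relies on `torch.autograd.graph.Node.register_prehook`, which fires with the gradients flowing into a node before the node executes. A minimal sketch in isolation (names are illustrative):

```python
import torch

t = torch.rand(2, requires_grad=True)
loss = (t * 3).sum()

captured = []

def prehook(grad_outputs):
    # Receives the gradients flowing into SumBackward0 before it runs.
    captured.append(grad_outputs)

loss.grad_fn.register_prehook(prehook)
loss.backward()
print(captured)  # [(tensor(1.),)]
```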
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD