pytorch/test
soulitzer 2eec02523b [autograd] Support GradientEdge as output for torch.autograd.grad (#127766)
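This is useful for splitting a `torch.autograd.grad` call into two parts (input gradients first, weight gradients afterwards) while preserving the intermediate gradients captured during the first part: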

<details>

<summary>
Click to see code
</summary>

```python
import collections
import weakref
from torch.autograd.graph import GradientEdge

def _get_grad_fn_or_grad_acc(t):
    if t.requires_grad and t.grad_fn is None:
        return t.view_as(t).grad_fn.next_functions[0][0]
    else:
        return t.grad_fn

def reverse_closure(roots, target_nodes):
    """Walk the reverse graph breadth-first from ``roots``.

    Each visited node is expected to expose a ``metadata`` mapping whose
    optional ``"reverse_edges"`` entry is a list of ``(holder_ref, idx)``
    pairs, where ``holder_ref`` is a weak reference to an object carrying the
    neighbouring node in its ``.node`` attribute.

    Args:
        roots: iterable of starting nodes; ``None`` entries and duplicates
            are skipped.
        target_nodes: container of nodes at which the walk stops. A target
            that is reached is recorded but neither added to the closure nor
            expanded further.

    Returns:
        A ``(closure, actual_target_nodes)`` tuple of sets: every node
        visited (roots included), and the subset of ``target_nodes`` that
        was actually reached.

    Raises:
        RuntimeError: if a weak reference to a holder is dead, i.e. the
            reverse graph has already been released.
    """
    closure = set()
    actual_target_nodes = set()
    queue = collections.deque()
    for node in roots:
        if node is not None and node not in closure:
            closure.add(node)
            queue.append(node)
    while queue:
        node = queue.popleft()
        for holder_ref, _idx in node.metadata.get("reverse_edges", []):
            holder = holder_ref()
            # A dead weak reference dereferences to None; only then is the
            # reverse graph gone. (The check used to be inverted and raised
            # for every *live* holder instead.)
            if holder is None:
                raise RuntimeError("Reverse graph is no longer alive")
            fn = holder.node
            if fn is None or fn in closure:
                continue
            if fn in target_nodes:
                # Stop at targets: record them, but do not walk past them.
                actual_target_nodes.add(fn)
                continue
            closure.add(fn)
            queue.append(fn)
    return closure, actual_target_nodes

# Enable weak pointer
class Holder:
    """Plain wrapper object around a node, so the wrapper can be weakly referenced."""

    def __init__(self, node):
        # The wrapped object; read back by callers through ``.node``.
        self.node = node

# TODO: use weak references to avoid reference cycle
def construct_reverse_graph(roots):
    """Record reverse edges on every node reachable from ``roots``.

    The graph is walked breadth-first through ``next_functions``. For each
    edge ``node -> fn`` a ``(weakref-to-Holder(node), idx)`` pair is appended
    to ``fn.metadata["reverse_edges"]``.

    Args:
        roots: iterable of starting nodes; ``None`` entries and duplicates
            are skipped.

    Returns:
        The list of ``Holder`` objects created during the walk. Only weak
        references to them are stored on the graph, so this list is what
        keeps the reverse edges alive.
    """
    pending = collections.deque()
    seen_roots = set()
    holders = []
    for root in roots:
        if root is None or root in seen_roots:
            continue
        pending.append(root)
        seen_roots.add(root)
    while pending:
        current = pending.popleft()
        for child, input_nr in current.next_functions:
            if child is None:
                continue
            # Don't necessarily need to store on the graph
            back_edges = child.metadata.get("reverse_edges", [])
            if not back_edges:
                # No reverse edge recorded yet: first visit, so expand it.
                pending.append(child)
            holder = Holder(current)
            holders.append(holder)
            back_edges.append((weakref.ref(holder), input_nr))
            child.metadata["reverse_edges"] = back_edges
    return holders

def get_param_groups(inputs, params):
    """Group ``params`` by the intermediate nodes they share with ``inputs``.

    The reverse closure of ``inputs`` is computed first. Each param is then
    walked in reverse until it hits that closure; the nodes where it hits are
    its "intermediates". Params that share an intermediate end up in the same
    group.

    Args:
        inputs: nodes whose reverse closure bounds the per-param walks.
        params: nodes to be grouped.

    Returns:
        A list of distinct group dicts, each with a ``"params"`` set and an
        ``"intermediates"`` set.

    Raises:
        AssertionError: if the union of all groups' params differs from
            ``set(params)``.
    """
    inputs_closure, _ = reverse_closure(inputs, set())
    groups_by_node = {}  # keyed on intermediates
    for param in params:
        _, intersected = reverse_closure([param], inputs_closure)
        current = {
            "params": {param},
            "intermediates": set(intersected),
        }
        for node in intersected:
            existing = groups_by_node.get(node)
            if existing is None:
                groups_by_node[node] = current
                continue
            # Fold the current group into the one already registered for
            # this node, then keep working with the merged group.
            existing["params"] = existing["params"] | current["params"]
            existing["intermediates"] = existing["intermediates"] | current["intermediates"]
            current = existing

    # Several intermediates may map to the same dict; keep the first
    # occurrence of each distinct object.
    unique_groups = list({id(group): group for group in groups_by_node.values()}.values())

    # Sanity check: union of all param_groups params should be equal to all params
    covered = set()
    for group in unique_groups:
        covered |= group["params"]
    assert covered == set(params)

    return unique_groups

def compute_grads_only_inputs2(roots, inps, weights):
    """Run the first half of a split backward: gradients w.r.t. ``inps`` only.

    Before the backward runs, a pre-hook is registered on every intermediate
    node of every param group. Each hook stores the argument it receives in
    ``param_group["grads"]`` at the intermediate's index, so the second half
    can later resume from those nodes.

    Args:
        roots: tensors to differentiate. Each one is seeded with
            ``torch.ones_like`` of itself.
        inps: tensors to compute gradients for.
        weights: tensors whose gradients are deferred to the second half.

    Returns:
        ``(dinputs, param_groups)``: the gradients w.r.t. ``inps`` and the
        param groups that the registered hooks fill in.
    """
    # Materialize once so ``roots`` can be used both for graph discovery and
    # for the backward call, even if it is a one-shot iterable.
    roots = tuple(roots)
    root_grad_fns = [_get_grad_fn_or_grad_acc(r) for r in roots]
    inp_grad_fns = [_get_grad_fn_or_grad_acc(i) for i in inps]
    weight_grad_fns = [_get_grad_fn_or_grad_acc(w) for w in weights]

    reverse_graph_refs = construct_reverse_graph(root_grad_fns)
    param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
    # The reverse graph is only needed to build the groups; drop the strong
    # references so it can be released.
    del reverse_graph_refs

    def get_hook(param_group, i):
        # Bind ``param_group`` and ``i`` per hook (avoids late binding).
        def hook(grad_inputs):
            if param_group.get("grads", None) is None:
                param_group["grads"] = [None] * len(param_group["intermediates"])
            param_group["grads"][i] = grad_inputs
        return hook

    for param_group in param_groups:
        for i, intermediate in enumerate(param_group["intermediates"]):
            # These are always "split" nodes that we need to recompute, so
            # save their inputs.
            intermediate.register_prehook(get_hook(param_group, i))

    # Differentiate the ``roots`` that were passed in. This used to read a
    # module-level global named ``out`` and ignore the ``roots`` argument.
    dinputs = torch.autograd.grad(
        roots,
        inputs=tuple(inps),
        grad_outputs=tuple(torch.ones_like(r) for r in roots),
        retain_graph=True,
    )
    return dinputs, param_groups

def compute_grads_only_weights2(user_weights, param_groups):
    """Run the second half of a split backward: gradients w.r.t. the weights.

    For each param group, ``torch.autograd.grad`` is called with
    ``GradientEdge`` objects for the group's intermediates as outputs and for
    its params as inputs, seeded with the gradients stored in
    ``param_group["grads"]``.

    Args:
        user_weights: weight tensors, in the order results should be returned.
        param_groups: groups with ``"params"``, ``"intermediates"`` and
            ``"grads"`` entries.

    Returns:
        A tuple of gradients, one per entry of ``user_weights``, in that order.
    """
    grads_by_node = {}
    for group in param_groups:
        # TODO: Handle case where intermediate can have multiple outputs
        intermediate_edges = tuple(GradientEdge(node, 0) for node in group["intermediates"])
        weights_edges = tuple(GradientEdge(node, 0) for node in group["params"])

        captured = group["grads"]
        assert all(len(g) == 1 for g in captured)
        # [NEW!] Able to pass a GradientEdge to autograd.grad as output
        # We do not need to retain_graph because... guarantee no overlap?
        print("trying to execute: ", intermediate_edges, weights_edges)
        dweights = torch.autograd.grad(intermediate_edges, weights_edges, grad_outputs=sum(captured, ()))
        grads_by_node.update(zip(group["params"], dweights))
    # return grads in the original order weights were provided in
    return tuple(grads_by_node[_get_grad_fn_or_grad_acc(w)] for w in user_weights)

```

</details>

```python
# ``import torch.nn as nn`` binds only ``nn``; ``torch`` itself must be
# imported for ``torch.rand`` and the other ``torch.*`` uses below.
import torch
import torch.nn as nn

# Setup: two stacked linear layers applied to a single input vector.
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)

# The input requires grad so that gradients w.r.t. it can be requested.
a = torch.rand(10, requires_grad=True)

# All parameters of both layers, mod1's first.
weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)

out = mod2(mod1(a))

class LoggingTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
    """Dispatch mode that prints the qualified name of each op it intercepts."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Run the op first, then log it as ``<module>.<name>``.
        result = func(*args, **({} if kwargs is None else kwargs))
        print(f"{func.__module__}.{func.__name__}")
        return result

print(" -- SPLIT -- ")
# Compute gradients in two parts, logging every dispatched op:
# part 1 yields the input gradients plus the per-group state,
# part 2 consumes that state to produce the weight gradients.
with LoggingTensorMode():
    print("PART 1")
    dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
    print("PART 2")
    dweights = compute_grads_only_weights2(weights, state)

# Re-run the forward pass so the reference backward below gets its own graph.
# NOTE(review): presumably needed because part 2 does not pass retain_graph -- confirm.
out = mod2(mod1(a))

print(" -- REF -- ")

# Compare with reference: one ordinary autograd.grad call over inputs and
# weights together.
with LoggingTensorMode():
    ref_all_gradients = torch.autograd.grad(out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),))

# Prints one boolean per gradient; results are printed, not asserted.
for actual, ref in zip(dinputs + dweights, ref_all_gradients):
    print(torch.allclose(actual, ref))

```

<img width="598" alt="image" src="https://github.com/pytorch/pytorch/assets/13428986/3681b8a7-3ab4-4d1d-a836-abef6913e671">

```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1ee160> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x100d7ae50> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a60> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a90> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute:  (GradientEdge(node=<AddmmBackward0 object at 0x12a1e4bb0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a21b130>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b7c0>, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a1cad60> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute:  (GradientEdge(node=<AddmmBackward0 object at 0x12a1ee0d0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a1e41c0>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b670>, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a21b2b0> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD
2024-07-16 21:46:19 +00:00
..
ao/sparsity
autograd
backends/xeon
benchmark_utils
bottleneck_test
cpp [structural binding][10/N] Replace std::tie with structural binding (#130784) 2024-07-16 10:28:14 +00:00
cpp_api_parity
cpp_extensions [9/N] Replace c10::optional with std::optional (#130674) 2024-07-15 00:48:43 +00:00
custom_backend
custom_operator Revert "Tighten torch.library.infer_schema input types (#130705)" 2024-07-16 12:57:11 +00:00
distributed Revert "[Traceable FSDP2][Inductor] Re-inplace all_gather_into_tensor (#129773)" 2024-07-16 20:54:14 +00:00
distributions
dynamo Revert "Propagate buffer and parameter indices through AOT (#130393)" 2024-07-16 15:43:34 +00:00
dynamo_expected_failures [3.12, 3.13, dynamo] simplified construction for frame f_locals/localsplus (#129185) 2024-07-12 17:56:38 +00:00
dynamo_skips [3.12, 3.13, dynamo] simplified construction for frame f_locals/localsplus (#129185) 2024-07-12 17:56:38 +00:00
edge
error_messages [Ez][BE]: Enable new stable ruff rules (#129825) 2024-07-02 14:47:10 +00:00
expect [FX][export] strict DCE pass, check schema for node impurity (#130552) 2024-07-12 15:43:27 +00:00
export [Fix]: Convert operator that does specialization to its symbolic counterpart (#129578) 2024-07-16 17:19:57 +00:00
forward_backward_compatibility [Inductor][Quant] Change the schema of QLinear Binary (#129049) 2024-07-02 12:36:38 +00:00
functorch Revert "Fix names conflict when lifting (#129817)" 2024-07-15 22:08:45 +00:00
fx [BE][Easy] fix ruff rule needless-bool (SIM103) (#130206) 2024-07-14 08:17:52 +00:00
higher_order_ops [Inductor] Fix the High Order Op layout issue (#128275) 2024-06-15 00:33:21 +00:00
inductor [autograd] Support GradientEdge as output for torch.autograd.grad (#127766) 2024-07-16 21:46:19 +00:00
jit [BE]: Update flake8-comprehensions and enable C420 (#130699) 2024-07-16 13:47:49 +00:00
jit_hooks
lazy [BE][Easy] replace import pathlib with from pathlib import Path (#129426) 2024-06-30 01:36:07 +00:00
mobile Revert "[BE][Easy] use pathlib.Path instead of dirname / ".." / pardir (#129374)" 2024-06-29 00:47:15 +00:00
nn Fix max_pool2d decomposition for empty list and integer limits (#129106) 2024-06-24 22:19:42 +00:00
onnx Revert "[ONNX] Remove beartype usage (#130484)" 2024-07-16 18:41:51 +00:00
optim
package
profiler [autograd] Support GradientEdge as output for torch.autograd.grad (#127766) 2024-07-16 21:46:19 +00:00
quantization Rename generate_numeric_debug_handle to numeric_debugger (#130590) 2024-07-15 22:42:27 +00:00
scripts
test_img
torch_np [BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() (#130199) 2024-07-11 17:30:28 +00:00
typing added type hints for __contains__ (#129653) 2024-06-30 11:49:11 +00:00
xpu
_test_bazel.py
allowlist_for_publicAPI.json Evaluate symexprs on load path of cache not write (#128997) 2024-06-20 08:55:12 +00:00
conftest.py run_test: Unset cpp stacktraces after reruns (#129004) 2024-07-03 01:50:15 +00:00
create_dummy_torchscript_model.py
delete.py
hi.py
HowToWriteTestsUsingFileCheck.md
linear.py
load_torchscript_model.py
minioptest_failures_dict.json
mkl_verbose.py
mkldnn_verbose.py
pytest_shard_custom.py
run_doctests.sh
run_test.py [BE][Easy] fix ruff rule needless-bool (SIM103) (#130206) 2024-07-14 08:17:52 +00:00
simulate_nccl_errors.py
test_ao_sparsity.py
test_autocast.py Revert "[MPS] Add support for autocast in MPS (#99272)" 2024-07-02 12:29:51 +00:00
test_autograd.py [autograd] Support GradientEdge as output for torch.autograd.grad (#127766) 2024-07-16 21:46:19 +00:00
test_autograd_fallback.py
test_autoload.py [RFC] Add support for device extension autoloading (#127074) 2024-07-09 06:14:13 +00:00
test_binary_ufuncs.py [BE]: Update ruff to 0.5.0 (#129744) 2024-06-28 21:49:56 +00:00
test_bundled_images.py
test_bundled_inputs.py
test_ci_sanity_check_fail.py
test_comparison_utils.py
test_compile_benchmark_util.py
test_complex.py
test_content_store.py
test_cpp_api_parity.py
test_cpp_extensions_aot.py
test_cpp_extensions_jit.py
test_cpp_extensions_mtia_backend.py
test_cpp_extensions_open_device_registration.py
test_cpp_extensions_stream_and_event.py [RELAND] Add xpu to getAccelerator (#129205) 2024-07-04 10:26:52 +00:00
test_cuda.py Support for expandable segments with cuda graph trees (#128068) 2024-07-15 23:23:23 +00:00
test_cuda_expandable_segments.py Support for expandable segments with cuda graph trees (#128068) 2024-07-15 23:23:23 +00:00
test_cuda_multigpu.py
test_cuda_nvml_based_avail.py
test_cuda_primary_ctx.py
test_cuda_sanitizer.py
test_cuda_trace.py
test_custom_ops.py Revert "Tighten torch.library.infer_schema input types (#130705)" 2024-07-16 12:57:11 +00:00
test_dataloader.py [BE] enable UFMT for torch/storage.py (#127706) 2024-06-27 23:16:24 +00:00
test_datapipe.py
test_decomp.py Set seed per sample for OpInfo tests + support for restricting to a single sample input (#128238) 2024-07-08 16:06:38 +00:00
test_deploy.py
test_determination.py [Caffe2] [2/N] Remove Caffe2 from tests (#128911) 2024-06-19 00:05:50 +00:00
test_dispatch.py
test_dlpack.py
test_dynamic_shapes.py Make hashing a SymInt raise an error again (#130548) 2024-07-16 18:30:30 +00:00
test_expanded_weights.py
test_fake_tensor.py [BE] update type annotations for basic utilities in torch/__init__.py (#129001) 2024-06-24 18:04:38 +00:00
test_flop_counter.py [cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343) 2024-06-30 19:22:16 +00:00
test_foreach.py Fix the rest of foreach flakers (#130277) 2024-07-09 02:08:21 +00:00
test_function_schema.py
test_functional_autograd_benchmark.py
test_functional_optim.py
test_functionalization.py
test_functionalization_of_rng_ops.py
test_futures.py
test_fx.py
test_fx_experimental.py [BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() (#130199) 2024-07-11 17:30:28 +00:00
test_fx_passes.py
test_fx_reinplace_pass.py
test_hub.py
test_import_stats.py
test_indexing.py Change index_put on GPU to accept FP8 inputs (#128758) 2024-06-25 00:38:03 +00:00
test_itt.py
test_jit.py [BE][Easy] fix ruff rule needless-bool (SIM103) (#130206) 2024-07-14 08:17:52 +00:00
test_jit_autocast.py
test_jit_disabled.py
test_jit_fuser.py
test_jit_fuser_legacy.py
test_jit_fuser_te.py
test_jit_legacy.py
test_jit_llga_fuser.py
test_jit_profiling.py
test_jit_simple.py
test_jit_string.py
test_jiterator.py
test_kernel_launch_checks.py
test_legacy_vmap.py
test_license.py
test_linalg.py fix torch.linalg.lstsq input check (#130612) 2024-07-12 23:06:52 +00:00
test_logging.py
test_masked.py
test_maskedtensor.py Set seed per sample for OpInfo tests + support for restricting to a single sample input (#128238) 2024-07-08 16:06:38 +00:00
test_matmul_cuda.py Updates to scaled_mm for rowwise scaling (#130059) 2024-07-04 00:53:17 +00:00
test_meta.py [dynamo] add meta fn for aten.kthvalue.default (#130562) 2024-07-12 23:48:31 +00:00
test_metal.py
test_mkl_verbose.py
test_mkldnn.py
test_mkldnn_fusion.py
test_mkldnn_verbose.py
test_mobile_optimizer.py
test_model_dump.py
test_model_exports_to_core_aten.py
test_module_tracker.py
test_modules.py [BE][Easy] fix ruff rule needless-bool (SIM103) (#130206) 2024-07-14 08:17:52 +00:00
test_monitor.py
test_mps.py Revert "[BE] bump optree version to 0.12.1 (#130139)" 2024-07-15 19:42:11 +00:00
test_multiprocessing.py Enable sharing meta tensors between processes (#129520) 2024-07-04 20:29:48 +00:00
test_multiprocessing_spawn.py
test_namedtensor.py
test_namedtuple_return_api.py sdp::SDPBackend::flash_attention support PrivateUse1 (#126392) 2024-06-28 17:48:40 +00:00
test_native_functions.py
test_native_mha.py
test_nestedtensor.py [NJT] throw an exception if nested_tensor_from_jagged is fx-traced without being fx.wrapped (#130702) 2024-07-16 19:21:10 +00:00
test_nn.py [BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() (#130199) 2024-07-11 17:30:28 +00:00
test_nnapi.py
test_numba_integration.py Add more dtypes to __cuda_array_interface__ (#129621) 2024-07-09 10:47:19 +00:00
test_numpy_interop.py
test_openmp.py [BE] enable UFMT for torch/storage.py (#127706) 2024-06-27 23:16:24 +00:00
test_ops.py [BE][Easy] fix ruff rule needless-bool (SIM103) (#130206) 2024-07-14 08:17:52 +00:00
test_ops_fwd_gradients.py
test_ops_gradients.py
test_ops_jit.py
test_optim.py Add testing regarding SparseAdam state_dicts (#130645) 2024-07-16 11:29:22 +00:00
test_out_dtype_op.py
test_overrides.py [BE][Easy] fix ruff rule needless-bool (SIM103) (#130206) 2024-07-14 08:17:52 +00:00
test_package.py
test_per_overload_api.py
test_prims.py Infer prim tags from equivalent aten ones (#130367) 2024-07-11 20:53:52 +00:00
test_proxy_tensor.py [dynamo] add meta fn for aten.kthvalue.default (#130562) 2024-07-12 23:48:31 +00:00
test_pruning_op.py
test_public_bindings.py Make public binding test only consider files that are packaged in the wheels (#130497) 2024-07-11 13:22:04 +00:00
test_python_dispatch.py
test_pytree.py
test_quantization.py Rename generate_numeric_debug_handle to numeric_debugger (#130590) 2024-07-15 22:42:27 +00:00
test_reductions.py Errors when 0-dim tensor of complex or bool type passed to aminmax. (#128404) 2024-06-24 21:46:49 +00:00
test_scatter_gather_ops.py
test_schema_check.py
test_segment_reductions.py
test_serialization.py Add torch.serialization.safe_globals context manager (#127939) 2024-07-12 20:38:43 +00:00
test_set_default_mobile_cpu_allocator.py
test_shape_ops.py
test_show_pickle.py
test_sort_and_select.py
test_sparse.py
test_sparse_csr.py [BE][Easy] apply autofix for ruff rules unnecessary-collection-call (C408): list() / tuple() / dict() (#130199) 2024-07-11 17:30:28 +00:00
test_sparse_semi_structured.py
test_spectral_ops.py
test_stateless.py Revert "[BE] bump optree version to 0.12.1 (#130139)" 2024-07-15 19:42:11 +00:00
test_static_runtime.py
test_subclass.py Only test _is_param if doing instance check on Parameter base (#130578) 2024-07-12 13:55:13 +00:00
test_sympy_utils.py Keep zero check be compatible with different sympy versions (#130729) 2024-07-16 08:39:00 +00:00
test_tensor_creation_ops.py Fix Storage.filename to not track the filename when storage was mmap-ed with MAP_PRIVATE (#128725) 2024-06-17 18:55:47 +00:00
test_tensorboard.py [dynamo][user-defined] Simplify and improve scope of UserDefinedObject var_getattr (#130169) 2024-07-08 04:10:56 +00:00
test_tensorexpr.py
test_tensorexpr_pybind.py
test_testing.py Revert "[BE] bump optree version to 0.12.1 (#130139)" 2024-07-15 19:42:11 +00:00
test_throughput_benchmark.py
test_torch.py Constant folding for dynamic shape node (#129686) 2024-07-16 00:17:11 +00:00
test_transformers.py [cpu][flash attention] fix nan issue (#130014) 2024-07-10 02:33:26 +00:00
test_type_hints.py Fix test test_type_hints.py::TestTypeHints::test_doc_examples (#129829) 2024-07-01 13:28:37 +00:00
test_type_info.py
test_type_promotion.py
test_typing.py Revert "[BE][Easy] use pathlib.Path instead of dirname / ".." / pardir (#129374)" 2024-06-29 00:47:15 +00:00
test_unary_ufuncs.py
test_utils.py
test_view_ops.py
test_vulkan.py
test_weak.py
test_xnnpack_integration.py Enable UFMT for numpy_test files, test_xnnpack_integration.py (#129023) 2024-06-28 05:40:31 +00:00
test_xpu.py Refine XPU UTs (#130138) 2024-07-05 09:56:22 +00:00