Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00
As called out in https://github.com/pytorch/pytorch/pull/137999, preserving signatures of multiple calls when buffer mutations are present was NYI (not yet implemented). The main problem was that intermediate values of buffers were not tracked, so they could not be propagated statefully between multiple calls (i.e., they would need to be passed around explicitly, defeating the unlifting needed for preserving signatures).

This PR fixes this by introducing module attributes that carry the necessary intermediate values of buffer mutations. In general, a buffer mutation can depend recursively on several intermediate values, including other buffers. So rather than tying an intermediate value to a particular buffer, we tie it to the submodules that create and read it. We install an attribute on every module that creates or reads a particular intermediate value, all sharing the same initial storage (i.e., initialized with the same empty tensor). The module that creates the intermediate value copies it into the corresponding attribute, and the modules that read it read that attribute instead.

Another complication was that a `run_decompositions` following an `export_for_training` did not preserve module call graphs, which are needed for unflattening and, in particular, for remapping inputs. Fortunately, existing metadata already tracks the provenance of nodes, which we use to update a module call graph after functionalization / decomposition.

Differential Revision: D64806175
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138669
Approved by: https://github.com/tugsbayasgalan
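The shared-attribute idea described above can be illustrated with a minimal sketch in plain eager PyTorch; it does not reproduce the graph rewriting that unflattening actually performs. `Producer`, `Consumer`, `call`, and the attribute name `_intermediate` are hypothetical. The sketch also assumes the intermediate's shape is known up front and preallocates the shared tensor with that shape, because `Tensor.copy_` does not resize its destination; how the real implementation handles shapes is not shown here.

```python
import torch
from torch import nn

# Sketch only: Producer, Consumer, and the attribute name "_intermediate"
# are hypothetical, not the names or code that torch.export generates.


class Producer(nn.Module):
    """Creates an intermediate value that a buffer mutation depends on."""

    def __init__(self, shared: torch.Tensor):
        super().__init__()
        # Same tensor object as the consumer's attribute (shared storage).
        self.register_buffer("_intermediate", shared)
        self.register_buffer("counter", torch.zeros(3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tmp = x * 2                    # intermediate value
        self.counter.add_(tmp)         # buffer mutation that uses it
        self._intermediate.copy_(tmp)  # write it into the shared attribute
        return x + 1


class Consumer(nn.Module):
    """Reads the intermediate value instead of receiving it as an input."""

    def __init__(self, shared: torch.Tensor):
        super().__init__()
        self.register_buffer("_intermediate", shared)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        return y + self._intermediate  # value written by Producer this call


# Assumption: the intermediate's shape (3,) is known ahead of time, so one
# preallocated tensor can be handed to both modules.
shared = torch.empty(3)
producer, consumer = Producer(shared), Consumer(shared)


def call(x: torch.Tensor) -> torch.Tensor:
    # Each call keeps the plain (x) -> out signature; the intermediate value
    # travels through module state rather than through extra inputs/outputs.
    return consumer(producer(x))


out1 = call(torch.ones(3))          # tmp = 2, producer returns 2, out = 4
out2 = call(torch.full((3,), 3.0))  # tmp = 6, producer returns 4, out = 10
```

Because the value is attached to the modules that write and read it, rather than to the `counter` buffer, the same mechanism would work if several buffers, or a chain of intermediates, depended on it.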
| Name |
|---|
| `__init__.py` |
| `opinfo_schema.py` |
| `test_converter.py` |
| `test_db.py` |
| `test_experimental.py` |
| `test_export.py` |
| `test_export_nonstrict.py` |
| `test_export_training_ir_to_run_decomp.py` |
| `test_functionalized_assertions.py` |
| `test_hop.py` |
| `test_lift_unlift.py` |
| `test_pass_infra.py` |
| `test_passes.py` |
| `test_retraceability.py` |
| `test_schema.py` |
| `test_serdes.py` |
| `test_serialize.py` |
| `test_sparse.py` |
| `test_swap.py` |
| `test_tools.py` |
| `test_torchbind.py` |
| `test_tree_utils.py` |
| `test_unflatten.py` |
| `test_unflatten_training_ir.py` |
| `test_verifier.py` |
| `testing.py` |