mirror of https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Previously, we would error when trying to preserve the call signature of a module that was called multiple times. With this PR, this now works without erroring. The fix is to propagate call indices in a few more places.

Note that while this works in the presence of params, buffers, and tensor constants, preserving call signatures for multiple calls to a module is not yet supported when buffers are mutated; this is future work. The main problem is that we do not have enough metadata to `copy_` mutated buffers at the end of each call to a module, so that the next call can read those buffers at the beginning. Making this work will likely require explicit tracking of the intermediate values of mutated buffers when collecting metadata during functionalization in export.

Note also that we stop short of creating a single graph out of multiple graphs; that is still future work. So the unflattened module will still have different targets `n`, `n@1`, `n@2`, etc., one for each call, when we ask for the module call signature of `n` to be preserved. However, it is much easier to swap all of these targets with a replacement that behaves similarly to the original, because all of these calls will respect the original module call signature. (In particular, any constant inputs will be carried by the calls.)

Differential Revision: D64406945
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137999
Approved by: https://github.com/tugsbayasgalan
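The commit message notes that unflattening still yields one target per call (`n`, `n@1`, `n@2`, ...) and that swapping them all for a single replacement is now straightforward. The sketch below illustrates only that naming convention. The helpers `group_call_targets` and `swap_all_calls` are hypothetical, not PyTorch APIs, and a plain dict stands in for the unflattened module's submodules.

```python
import re
from collections import defaultdict

# Hypothetical helper, not part of PyTorch. It matches a module FQN that is
# optionally followed by "@<call index>", e.g. "n", "n@1", or "a.b@2".
_CALL_TARGET = re.compile(r"^(?P<fqn>[^@]+)(?:@(?P<idx>\d+))?$")


def group_call_targets(names):
    """Group per-call target names by their original module FQN.

    The first call carries no suffix and is treated as call index 0.
    Each FQN maps to its target names ordered by numeric call index.
    """
    groups = defaultdict(list)
    for name in names:
        m = _CALL_TARGET.match(name)
        if m is None:
            continue  # skip names that do not follow the fqn[@k] pattern
        idx = int(m.group("idx")) if m.group("idx") else 0
        groups[m.group("fqn")].append((idx, name))
    # Sort numerically so that "n@10" comes after "n@2".
    return {fqn: [name for _, name in sorted(pairs)] for fqn, pairs in groups.items()}


def swap_all_calls(targets, fqn, replacement):
    """Point every per-call target of `fqn` at the same replacement object."""
    # Group first, so the dict is not mutated while its keys are being scanned.
    for name in group_call_targets(targets).get(fqn, []):
        targets[name] = replacement
    return targets


if __name__ == "__main__":
    targets = {"n": "call0", "n@1": "call1", "n@2": "call2", "m": "other"}
    print(swap_all_calls(targets, "n", "replacement"))
```

This approach works only because every per-call target respects the same preserved call signature. If the signatures differed, a single replacement could not stand in for all of them.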
| File |
|---|
| __init__.py |
| _node_metadata_hook.py |
| add_runtime_assertions_for_constraints_pass.py |
| collect_tracepoints_pass.py |
| constant_folding.py |
| functionalize_side_effectful_ops_pass.py |
| lift_constants_pass.py |
| remove_runtime_assertions.py |
| replace_autocast_with_hop_pass.py |
| replace_quantized_ops_with_standard_ops_pass.py |
| replace_set_grad_with_hop_pass.py |
| replace_view_ops_with_view_copy_ops_pass.py |
| replace_with_hop_pass_util.py |