Mirror of https://github.com/saymrwulf/pytorch.git, synced 2026-05-14 20:57:59 +00:00.
Summary: Add experimental support for `torch.nn.Module` as an input type.

Before this change, module inputs were not supported. We recently saw some interesting use cases, such as gpt-fast (https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py#L68), where a module is passed directly as an input to select among variants of the same model. Non-strict mode does not track state other than parameters and buffers, so we ignore such state on input modules as well and treat it as plain constants during tracing.

We treat any module input as a nested container of tensors. For each such module type, we automatically register a pytree handler that flattens the module's state dict into a group of tensors. Any method call on an input module is inlined during tracing, as is already done for the `self` module in `export_for_training`. In the typical case, input modules therefore behave much like the module being exported. The difference is that their tensors are recorded as plain user inputs rather than as parameters or buffers.

Test Plan: `buck run mode/opt caffe2/test:test_export -- -r test_module_input`

Differential Revision: D67680827

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143925

Approved by: https://github.com/tugsbayasgalan
| File |
|---|
| `__init__.py` |
| `opinfo_schema.py` |
| `random_dag.py` |
| `test_converter.py` |
| `test_cpp_serdes.py` |
| `test_db.py` |
| `test_draft_export.py` |
| `test_experimental.py` |
| `test_export.py` |
| `test_export_legacy.py` |
| `test_export_nonstrict.py` |
| `test_export_training_ir_to_run_decomp.py` |
| `test_functionalized_assertions.py` |
| `test_hop.py` |
| `test_lift_unlift.py` |
| `test_pass_infra.py` |
| `test_passes.py` |
| `test_retraceability.py` |
| `test_schema.py` |
| `test_serdes.py` |
| `test_serialize.py` |
| `test_sparse.py` |
| `test_swap.py` |
| `test_tools.py` |
| `test_torchbind.py` |
| `test_tree_utils.py` |
| `test_unflatten.py` |
| `test_unflatten_training_ir.py` |
| `test_verifier.py` |
| `testing.py` |
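The commit summary says an input module is treated as a nested container of tensors, with a pytree handler registered per module type that flattens its state dict. The torch-free sketch below illustrates that idea under stated assumptions. `register_node`, `tree_flatten`, `tree_unflatten`, `ensure_registered`, `traced_call`, and `ToyModule` are hypothetical helpers written for illustration. They are not the `torch.utils._pytree` API or the PR's implementation, and plain floats stand in for tensors.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

# Hypothetical toy pytree registry: type -> (flatten_fn, unflatten_fn).
# It mimics the idea behind torch.utils._pytree but is NOT that API.
_REGISTRY: Dict[type, Tuple[Callable, Callable]] = {}


def register_node(cls: type, flatten_fn: Callable, unflatten_fn: Callable) -> None:
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)


def tree_flatten(obj: Any) -> Tuple[List[Any], Any]:
    """Flatten obj into (leaves, spec); unregistered types are leaves."""
    handler = _REGISTRY.get(type(obj))
    if handler is None:
        return [obj], None
    children, context = handler[0](obj)
    leaves: List[Any] = []
    child_specs = []
    for child in children:
        child_leaves, child_spec = tree_flatten(child)
        leaves.extend(child_leaves)
        child_specs.append((len(child_leaves), child_spec))
    return leaves, (type(obj), context, child_specs)


def tree_unflatten(leaves: List[Any], spec: Any) -> Any:
    """Rebuild an object from leaves using the spec produced by tree_flatten."""
    if spec is None:
        return leaves[0]
    cls, context, child_specs = spec
    children, i = [], 0
    for n, child_spec in child_specs:
        children.append(tree_unflatten(leaves[i:i + n], child_spec))
        i += n
    return _REGISTRY[cls][1](children, context)


class ToyModule:
    """Hypothetical stand-in for nn.Module; floats play the role of tensors."""

    def __init__(self, state: Dict[str, float], config: Optional[str] = None):
        self._state = dict(state)  # plays the role of parameters and buffers
        self.config = config       # other state: treated as a constant

    def state_dict(self) -> Dict[str, float]:
        return dict(self._state)

    def forward(self, x: float) -> float:
        return self._state["weight"] * x + self._state["bias"]


def ensure_registered(cls: type) -> None:
    """Register a flatten/unflatten pair for a module type on first use."""
    if cls in _REGISTRY:
        return

    def flatten(mod):
        sd = mod.state_dict()
        keys = sorted(sd)
        # State-dict values become leaves; keys and non-state attrs go in the context.
        return [sd[k] for k in keys], (keys, mod.config)

    def unflatten(children, context):
        keys, config = context
        return cls(dict(zip(keys, children)), config)

    register_node(cls, flatten, unflatten)


def traced_call(leaves: List[float], spec: Any, x: float) -> float:
    """Rebuild the module from flat inputs and inline its forward call."""
    return tree_unflatten(leaves, spec).forward(x)


m = ToyModule({"weight": 2.0, "bias": 0.5}, config="variant-a")
ensure_registered(ToyModule)
leaves, spec = tree_flatten(m)
print(leaves)                              # [0.5, 2.0]
print(traced_call([1.0, 3.0], spec, 2.0))  # 7.0
```

The sketch keeps the state-dict keys and the non-parameter attribute (`config`) in the spec's context instead of the leaves. This loosely mirrors the summary's choice to treat non-parameter, non-buffer state as constants. Only the leaves vary between calls, so another variant with the same keys changes only the flat inputs.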