Resolve #137540

Summary:
We might get different state_dict and named_parameters() results when a module has registered custom state_dict hooks. For an ExportedProgram's state_dict, we want the state_dict to reflect the actual module hierarchy at runtime, which can differ from the model's state_dict() output when the model has state_dict hooks. To do weight swapping, one needs to either re-export or turn off the hooks when saving the model's state_dict().

Previously, ExportedProgram used nn.Module's state_dict() method to populate its own state_dict, but that doesn't work for some models (e.g. llama3_2_vision) because an ExportedProgram's state_dict and an nn.Module's state_dict have subtly different semantics. An nn.Module's state_dict describes how the state should be serialized, and it reflects the structure of the original user model code. In contrast, export specializes on a "run" of a model, and its state_dict needs to reflect the runtime module hierarchy.

One example where the two differ is TorchTune's Llama3_2_vision text decoder. There, a FusionLayer is added as a local optimization and is not part of the "static model definition". At runtime we have mod.layers[3].layer.sa_norm.scale, but the authors of the model added a state_dict hook that removes the "layer" level in mod.state_dict() to reflect the static model definition, so we get mod.state_dict()["layers.3.sa_norm.scale"]. The sketch below illustrates this divergence.
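To make the divergence concrete, here is a minimal, self-contained sketch (not the actual TorchTune code; SAModule, FusionLayer, and Model are hypothetical stand-ins) of a hook that hides a wrapper level via the private `_register_state_dict_hook` API:

```python
import torch
import torch.nn as nn

class SAModule(nn.Module):
    # Hypothetical stand-in for a decoder layer with a norm parameter.
    def __init__(self):
        super().__init__()
        self.sa_norm = nn.LayerNorm(4)

    def forward(self, x):
        return self.sa_norm(x)

class FusionLayer(nn.Module):
    # Hypothetical stand-in for TorchTune's FusionLayer: wraps a layer and
    # registers a hook so state_dict() hides the extra "layer" level.
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self._register_state_dict_hook(FusionLayer._strip_layer_level)

    @staticmethod
    def _strip_layer_level(module, state_dict, prefix, local_metadata):
        # Rewrite "<prefix>layer.<rest>" -> "<prefix><rest>" so the saved
        # state reflects the static model definition, not the wrapper.
        marker = prefix + "layer."
        for key in list(state_dict.keys()):
            if key.startswith(marker):
                state_dict[prefix + key[len(marker):]] = state_dict.pop(key)
        return state_dict

    def forward(self, x):
        return self.layer(x)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([SAModule(), FusionLayer(SAModule())])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

m = Model()
print(sorted(m.state_dict().keys()))
# ['layers.0.sa_norm.bias', 'layers.0.sa_norm.weight',
#  'layers.1.sa_norm.bias', 'layers.1.sa_norm.weight']   <- hook removed "layer"
print(sorted(name for name, _ in m.named_parameters()))
# ['layers.0.sa_norm.bias', 'layers.0.sa_norm.weight',
#  'layers.1.layer.sa_norm.bias', 'layers.1.layer.sa_norm.weight']
```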
In this Diff, we change ExportedProgram to populate its state_dict using named_parameters() and named_buffers() instead. So in the ExportedProgram's state_dict we have "layers.3.layer.sa_norm.scale", which reflects the runtime module hierarchy.
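Continuing the sketch above, this is the expected effect of the change on the toy model (the key sets shown are illustrative, not captured output):

```python
# Reuses Model from the sketch above.
ep = torch.export.export(Model(), (torch.randn(2, 4),))
print(sorted(ep.state_dict.keys()))
# Expected after this Diff (runtime hierarchy, "layer" level preserved):
# ['layers.0.sa_norm.bias', 'layers.0.sa_norm.weight',
#  'layers.1.layer.sa_norm.bias', 'layers.1.layer.sa_norm.weight']
```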
One problem this presents is weight swapping. Since the ExportedProgram's state and the model's state are no longer the same, the weight-swapping procedure also needs to change slightly. In internal Ads and RecSys model deployments, weight swapping means there is one model currently deployed and serving traffic, and we want to swap out its weights with newly trained weights without redoing the whole exporting/lowering process and creating a new artifact: the deployed model's pointer to the state dict is moved over to the new state dict. Because of this, it was previously a requirement that the FQNs match between the exported and the eager model's state dicts.

The new ExportedProgram state_dict still supports weight swapping, but the state_dict to be swapped needs to be obtained from torch.export.exported_program instead of model.state_dict() if the model has state_dict hooks. The new requirement is that the FQNs match between the exported program's state dict and the state_dict obtained under the `_disabled_load_state_dict_hooks(M)` context manager, as sketched below. One benefit of this new API is that we are now in full control within export of gathering and updating the model state. If a model doesn't have any state_dict hooks, one can still use model.state_dict() for weight swapping, so this change is BC.
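Continuing the sketch, the new weight-swapping flow might then look like the following. Only the context-manager name comes from this summary; its import path, and the in-place dict update standing in for the deployment-side pointer swap, are assumptions:

```python
# Assumed import path, based on "obtained from torch.export.exported_program".
from torch.export.exported_program import _disabled_load_state_dict_hooks

trained = Model()  # newly trained eager model, hooks still registered

# Gather the trained state with hooks disabled so FQNs match the exported
# program's runtime-hierarchy keys (e.g. "layers.1.layer.sa_norm.weight").
with _disabled_load_state_dict_hooks(trained):
    swap_state = trained.state_dict()

# Swap the deployed program's weights without re-exporting or re-lowering.
# (Production flows repoint the state dict; an in-place update is shown here.)
assert swap_state.keys() == ep.state_dict.keys()
ep.state_dict.update(swap_state)
```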
Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_export_for_training_with_state_dict_hooks
```

Differential Revision: D64080561

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137609
Approved by: https://github.com/angelayi, https://github.com/pianpwk