mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Optimized set_initialized_submodules. (#35493)
This commit is contained in:
parent
7051c5fcc8
commit
568941bf11
1 changed files with 6 additions and 4 deletions
|
|
@ -565,13 +565,15 @@ def set_initialized_submodules(model, state_dict_keys):
|
|||
Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
|
||||
dict.
|
||||
"""
|
||||
state_dict_keys = set(state_dict_keys)
|
||||
not_initialized_submodules = {}
|
||||
for module_name, module in model.named_modules():
|
||||
loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")}
|
||||
# When checking if the root module is loaded all state_dict_keys must be used.
|
||||
if module_name == "":
|
||||
loaded_keys = set(state_dict_keys)
|
||||
if loaded_keys.issuperset(module.state_dict()):
|
||||
# When checking if the root module is loaded there's no need to prepend module_name.
|
||||
module_keys = set(module.state_dict())
|
||||
else:
|
||||
module_keys = {f"{module_name}.{k}" for k in module.state_dict()}
|
||||
if module_keys.issubset(state_dict_keys):
|
||||
module._is_hf_initialized = True
|
||||
else:
|
||||
not_initialized_submodules[module_name] = module
|
||||
|
|
|
|||
Loading…
Reference in a new issue