Optimized set_initialized_submodules. (#35493)

This commit is contained in:
v2ray 2025-01-22 00:01:28 +08:00 committed by GitHub
parent 7051c5fcc8
commit 568941bf11
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -565,13 +565,15 @@ def set_initialized_submodules(model, state_dict_keys):
Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
dict.
"""
state_dict_keys = set(state_dict_keys)
not_initialized_submodules = {}
for module_name, module in model.named_modules():
loaded_keys = {k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")}
# When checking if the root module is loaded all state_dict_keys must be used.
if module_name == "":
loaded_keys = set(state_dict_keys)
if loaded_keys.issuperset(module.state_dict()):
# When checking if the root module is loaded there's no need to prepend module_name.
module_keys = set(module.state_dict())
else:
module_keys = {f"{module_name}.{k}" for k in module.state_dict()}
if module_keys.issubset(state_dict_keys):
module._is_hf_initialized = True
else:
not_initialized_submodules[module_name] = module