This commit is contained in:
Cyril Vallez 2025-02-07 11:39:28 +01:00
parent f88bb46428
commit c95dc4ebe7
No known key found for this signature in database
2 changed files with 7 additions and 5 deletions

View file

@ -1476,12 +1476,12 @@ def _find_missing_and_unexpected_keys(
# Model-specific exceptions for missing and unexpected keys (e.g. if the modeling change over time, or any other reason...)
if cls._keys_to_ignore_on_load_missing is not None:
for pat in cls._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
for pattern in cls._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]
if cls._keys_to_ignore_on_load_unexpected is not None:
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
for pattern in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]
return missing_keys, unexpected_keys

View file

@ -2358,7 +2358,9 @@ class ModelTesterMixin:
if model_reloaded._keys_to_ignore_on_load_missing is None:
expected_missing = set()
else:
expected_missing = set(model_reloaded._keys_to_ignore_on_load_missing)
expected_missing = set()
for pattern in model_reloaded._keys_to_ignore_on_load_missing:
expected_missing.update({k for k in model_reloaded.state_dict().keys() if re.search(pattern, k) is not None})
self.assertEqual(
missed_missing,
expected_missing,