diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c8624124e..cca7bbb56 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2360,7 +2360,7 @@ class ModelTesterMixin: else: expected_missing = set() for pattern in model_reloaded._keys_to_ignore_on_load_missing: - expected_missing.update({k for k in model_reloaded.state_dict().keys() if re.search(pattern, k) is not None}) + expected_missing.update({k for k in param_names if re.search(pattern, k) is not None}) self.assertEqual( missed_missing, expected_missing,