diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index dbdf45ac8..1af00b909 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -564,13 +564,34 @@ class ModelTesterMixin: model_state_dict = model.state_dict() loaded_model_state_dict = loaded_model.state_dict() + non_persistent_buffers = {} + for key in loaded_model_state_dict.keys(): + if key not in model_state_dict.keys(): + non_persistent_buffers[key] = loaded_model_state_dict[key] + + loaded_model_state_dict = { + key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers + } + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + model_buffers = list(model.buffers()) + for non_persistent_buffer in non_persistent_buffers.values(): + found_buffer = False + for i, model_buffer in enumerate(model_buffers): + if torch.equal(non_persistent_buffer, model_buffer): + found_buffer = True + break + + self.assertTrue(found_buffer) + model_buffers.pop(i) + models_equal = True for layer_name, p1 in model_state_dict.items(): - p2 = loaded_model_state_dict[layer_name] - if p1.data.ne(p2.data).sum() > 0: - models_equal = False + if layer_name in loaded_model_state_dict: + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False self.assertTrue(models_equal)