From 7bff0af0a4b09b9bfc3ade2532c2a3acdff95eda Mon Sep 17 00:00:00 2001 From: Harutaka Kawamura Date: Tue, 27 Oct 2020 23:37:04 +0900 Subject: [PATCH] Fix a bug for `CallbackHandler.callback_list` (#8052) * Fix callback_list * Add test Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> * Fix test Signed-off-by: harupy <17039389+harupy@users.noreply.github.com> --- src/transformers/trainer_callback.py | 2 +- tests/test_trainer_callback.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 08277b6fd..398d5fd35 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -325,7 +325,7 @@ class CallbackHandler(TrainerCallback): @property def callback_list(self): - return "\n".join(self.callbacks) + return "\n".join(cb.__class__.__name__ for cb in self.callbacks) def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): return self.call_event("on_init_end", args, state, control) diff --git a/tests/test_trainer_callback.py b/tests/test_trainer_callback.py index 133c4e29f..cc21d2d57 100644 --- a/tests/test_trainer_callback.py +++ b/tests/test_trainer_callback.py @@ -221,3 +221,10 @@ class TrainerCallbackTest(unittest.TestCase): trainer.train() events = trainer.callback_handler.callbacks[-2].events self.assertEqual(events, self.get_expected_events(trainer)) + + # warning should be emitted for duplicated callbacks + with unittest.mock.patch("transformers.trainer_callback.logger.warn") as warn_mock: + trainer = self.get_trainer( + callbacks=[MyTestTrainerCallback, MyTestTrainerCallback], + ) + assert str(MyTestTrainerCallback) in warn_mock.call_args[0][0]