From 9e795eac8806eb1a473b4a4f8b9e38c2df56794e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 Feb 2021 16:04:28 +0300 Subject: [PATCH] fix bert2bert test (#10063) --- tests/test_trainer_seq2seq.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/tests/test_trainer_seq2seq.py b/tests/test_trainer_seq2seq.py index e91515ca2..7931ca844 100644 --- a/tests/test_trainer_seq2seq.py +++ b/tests/test_trainer_seq2seq.py @@ -24,15 +24,9 @@ if is_datasets_available(): class Seq2seqTrainerTester(TestCasePlus): @slow - @require_datasets @require_torch + @require_datasets def test_finetune_bert2bert(self): - """ - Currently fails with: - - ImportError: To be able to use this metric, you need to install the following dependencies['absl', 'nltk', 'rouge_score'] - """ - bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny") tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") @@ -47,8 +41,6 @@ class Seq2seqTrainerTester(TestCasePlus): train_dataset = train_dataset.select(range(32)) val_dataset = val_dataset.select(range(16)) - rouge = datasets.load_metric("rouge") - batch_size = 4 def _map_to_encoder_decoder_inputs(batch): @@ -78,15 +70,9 @@ class Seq2seqTrainerTester(TestCasePlus): pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True) - rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])[ - "rouge2" - ].mid + accuracy = sum([int(pred_str[i] == label_str[i]) for i in range(len(pred_str))]) / len(pred_str) - return { - "rouge2_precision": round(rouge_output.precision, 4), - "rouge2_recall": round(rouge_output.recall, 4), - "rouge2_fmeasure": round(rouge_output.fmeasure, 4), - } + return {"accuracy": accuracy} # map train dataset train_dataset = train_dataset.map(