From ebba39e4e1d27e159d22d442e326a11cfbc10d31 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 15 Jun 2020 22:50:09 +0200 Subject: [PATCH] [Bart] Question Answering Model is added to tests (#5024) * fix test * Update tests/test_modeling_common.py * Update tests/test_modeling_common.py --- tests/test_modeling_bart.py | 4 +++- tests/test_modeling_common.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index c48f20dc0..8ceee5e26 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -113,7 +113,9 @@ def prepare_bart_inputs_dict( @require_torch class BARTModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( - (BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else () + (BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering) + if is_torch_available() + else () ) all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else () is_encoder_decoder = True diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index ef805d91e..4fb0d53db 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -38,6 +38,7 @@ if is_torch_available(): BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_LIST, MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + MODEL_FOR_QUESTION_ANSWERING_MAPPING, top_k_top_p_filtering, ) @@ -180,8 +181,13 @@ class ModelTesterMixin: correct_outlen = 4 decoder_attention_idx = 1 - if "lm_labels" in inputs_dict: # loss will come first - correct_outlen += 1 # compute loss + # loss is at first position + if "labels" in inputs_dict: + correct_outlen += 1 # loss is added to beginning + decoder_attention_idx += 1 + # Question Answering model returns start_logits and end_logits + if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values(): + correct_outlen += 1 # start_logits and end_logits instead of only 1 output decoder_attention_idx += 1 self.assertEqual(out_len, correct_outlen)