[Bart] Question Answering Model is added to tests (#5024)

* fix test

* Update tests/test_modeling_common.py

* Update tests/test_modeling_common.py
This commit is contained in:
Patrick von Platen 2020-06-15 22:50:09 +02:00 committed by GitHub
parent bbad4c6989
commit ebba39e4e1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 3 deletions

View file

@ -113,7 +113,9 @@ def prepare_bart_inputs_dict(
@require_torch
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
(BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
if is_torch_available()
else ()
)
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
is_encoder_decoder = True

View file

@ -38,6 +38,7 @@ if is_torch_available():
BertConfig,
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
top_k_top_p_filtering,
)
@ -180,8 +181,13 @@ class ModelTesterMixin:
correct_outlen = 4
decoder_attention_idx = 1
if "lm_labels" in inputs_dict: # loss will come first
correct_outlen += 1 # compute loss
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
decoder_attention_idx += 1
# Question Answering model returns start_logits and end_logits
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
decoder_attention_idx += 1
self.assertEqual(out_len, correct_outlen)