mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[Bart] Question Answering Model is added to tests (#5024)
* fix test * Update tests/test_modeling_common.py * Update tests/test_modeling_common.py
This commit is contained in:
parent
bbad4c6989
commit
ebba39e4e1
2 changed files with 11 additions and 3 deletions
|
|
@ -113,7 +113,9 @@ def prepare_bart_inputs_dict(
|
|||
@require_torch
|
||||
class BARTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(BartModel, BartForConditionalGeneration, BartForSequenceClassification) if is_torch_available() else ()
|
||||
(BartModel, BartForConditionalGeneration, BartForSequenceClassification, BartForQuestionAnswering)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
all_generative_model_classes = (BartForConditionalGeneration,) if is_torch_available() else ()
|
||||
is_encoder_decoder = True
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ if is_torch_available():
|
|||
BertConfig,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
|
||||
|
|
@ -180,8 +181,13 @@ class ModelTesterMixin:
|
|||
correct_outlen = 4
|
||||
decoder_attention_idx = 1
|
||||
|
||||
if "lm_labels" in inputs_dict: # loss will come first
|
||||
correct_outlen += 1 # compute loss
|
||||
# loss is at first position
|
||||
if "labels" in inputs_dict:
|
||||
correct_outlen += 1 # loss is added to beginning
|
||||
decoder_attention_idx += 1
|
||||
# Question Answering model returns start_logits and end_logits
|
||||
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
|
||||
decoder_attention_idx += 1
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue