diff --git a/docs/source/model_doc/distilbert.rst b/docs/source/model_doc/distilbert.rst
index f5b3727f5..67f27495b 100644
--- a/docs/source/model_doc/distilbert.rst
+++ b/docs/source/model_doc/distilbert.rst
@@ -75,6 +75,13 @@ DistilBertForSequenceClassification
     :members:
 
 
+DistilBertForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.DistilBertForMultipleChoice
+    :members:
+
+
 DistilBertForTokenClassification
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 5f95bb86d..3c4d568b6 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -271,6 +271,7 @@ if is_torch_available():
         DistilBertPreTrainedModel,
         DistilBertForMaskedLM,
         DistilBertModel,
+        DistilBertForMultipleChoice,
         DistilBertForSequenceClassification,
         DistilBertForQuestionAnswering,
         DistilBertForTokenClassification,
diff --git a/src/transformers/configuration_distilbert.py b/src/transformers/configuration_distilbert.py
index 09f37d963..3f74e4f2c 100644
--- a/src/transformers/configuration_distilbert.py
+++ b/src/transformers/configuration_distilbert.py
@@ -75,7 +75,7 @@ class DistilBertConfig(PretrainedConfig):
             The dropout probabilities used in the question answering model
             :class:`~transformers.DistilBertForQuestionAnswering`.
         seq_classif_dropout (:obj:`float`, optional, defaults to 0.2):
-            The dropout probabilities used in the sequence classification model
+            The dropout probabilities used in the sequence classification and the multiple choice model
             :class:`~transformers.DistilBertForSequenceClassification`.
 
     Example::
diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py
index 23120d382..6cfee27d4 100644
--- a/src/transformers/modeling_auto.py
+++ b/src/transformers/modeling_auto.py
@@ -78,6 +78,7 @@ from .modeling_camembert import (
 from .modeling_ctrl import CTRLLMHeadModel, CTRLModel
 from .modeling_distilbert import (
     DistilBertForMaskedLM,
+    DistilBertForMultipleChoice,
     DistilBertForQuestionAnswering,
     DistilBertForSequenceClassification,
     DistilBertForTokenClassification,
@@ -314,6 +315,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
         (LongformerConfig, LongformerForMultipleChoice),
         (RobertaConfig, RobertaForMultipleChoice),
         (BertConfig, BertForMultipleChoice),
+        (DistilBertConfig, DistilBertForMultipleChoice),
         (XLNetConfig, XLNetForMultipleChoice),
         (AlbertConfig, AlbertForMultipleChoice),
     ]
diff --git a/src/transformers/modeling_distilbert.py b/src/transformers/modeling_distilbert.py
index f802573d1..1d2c8b80b 100644
--- a/src/transformers/modeling_distilbert.py
+++ b/src/transformers/modeling_distilbert.py
@@ -864,3 +864,111 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
             outputs = (loss,) + outputs
 
         return outputs  # (loss), scores, (hidden_states), (attentions)
+
+
+@add_start_docstrings(
+    """DistilBert Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
""", + DISTILBERT_START_DOCSTRING, +) +class DistilBertForMultipleChoice(DistilBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.distilbert = DistilBertModel(config) + self.pre_classifier = nn.Linear(config.dim, config.dim) + self.classifier = nn.Linear(config.dim, 1) + self.dropout = nn.Dropout(config.seq_classif_dropout) + + self.init_weights() + + @add_start_docstrings_to_callable(DISTILBERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)")) + def forward( + self, + input_ids=None, + attention_mask=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the multiple choice classification loss. + Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension + of the input tensors. (see `input_ids` above) + + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification loss. + classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`): + `num_choices` is the second dimension of the input tensors. (see `input_ids` above). + + Classification scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + + Examples:: + + from transformers import DistilBertTokenizer, DistilBertForMultipleChoice + import torch + + tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased') + model = DistilBertForMultipleChoice.from_pretrained('distilbert-base-cased') + + prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + choice0 = "It is eaten with a fork and a knife." + choice1 = "It is eaten while held in the hand." 
+        labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1
+
+        encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
+        outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()}, labels=labels)  # batch size is 1
+
+        # the linear classifier still needs to be trained
+        loss, logits = outputs[:2]
+        """
+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        inputs_embeds = (
+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+            if inputs_embeds is not None
+            else None
+        )
+
+        outputs = self.distilbert(
+            input_ids,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+        )
+
+        hidden_state = outputs[0]  # (bs * num_choices, seq_len, dim)
+        pooled_output = hidden_state[:, 0]  # (bs * num_choices, dim)
+        pooled_output = self.pre_classifier(pooled_output)  # (bs * num_choices, dim)
+        pooled_output = nn.ReLU()(pooled_output)  # (bs * num_choices, dim)
+        pooled_output = self.dropout(pooled_output)  # (bs * num_choices, dim)
+        logits = self.classifier(pooled_output)  # (bs * num_choices, 1)
+
+        reshaped_logits = logits.view(-1, num_choices)  # (bs, num_choices)
+
+        outputs = (reshaped_logits,) + outputs[1:]  # add hidden states and attention if they are here
+
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(reshaped_logits, labels)
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
diff --git a/tests/test_modeling_distilbert.py b/tests/test_modeling_distilbert.py
index a5b9024ac..2043cac82 100644
--- a/tests/test_modeling_distilbert.py
+++ b/tests/test_modeling_distilbert.py
@@ -28,6 +28,7 @@ if is_torch_available():
         DistilBertConfig,
         DistilBertModel,
         DistilBertForMaskedLM,
+        DistilBertForMultipleChoice,
         DistilBertForTokenClassification,
         DistilBertForQuestionAnswering,
         DistilBertForSequenceClassification,
@@ -41,6 +42,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
         (
             DistilBertModel,
             DistilBertForMaskedLM,
+            DistilBertForMultipleChoice,
             DistilBertForQuestionAnswering,
             DistilBertForSequenceClassification,
             DistilBertForTokenClassification,
@@ -218,6 +220,25 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
             )
             self.check_loss_output(result)
 
+        def create_and_check_distilbert_for_multiple_choice(
+            self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
+        ):
+            config.num_choices = self.num_choices
+            model = DistilBertForMultipleChoice(config=config)
+            model.to(torch_device)
+            model.eval()
+            multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            loss, logits = model(
+                multiple_choice_inputs_ids, attention_mask=multiple_choice_input_mask, labels=choice_labels,
+            )
+            result = {
+                "loss": loss,
+                "logits": logits,
+            }
+            self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
+            self.check_loss_output(result)
+
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
             (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
@@ -251,6 +272,10 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)
 
+    def test_for_multiple_choice(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs)
+
     # @slow
     # def test_model_from_pretrained(self):
     #     for model_name in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
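Reviewer note: as a sanity check of the new head's input contract, here is a minimal, hedged sketch of how the added forward pass could be exercised on a single SWAG-style example. It only restates the docstring example from the patch with the (batch_size, num_choices, sequence_length) layout made explicit; the checkpoint name and tokenizer calls come from that docstring, and the final backward step is purely illustrative since the classification head is freshly initialised.

    import torch
    from transformers import DistilBertTokenizer, DistilBertForMultipleChoice

    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-cased")
    model = DistilBertForMultipleChoice.from_pretrained("distilbert-base-cased")

    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    choices = ["It is eaten with a fork and a knife.", "It is eaten while held in the hand."]

    # Encode every (prompt, choice) pair, then add a leading batch dimension so each
    # tensor is (batch_size=1, num_choices=2, sequence_length), as forward() expects.
    encoding = tokenizer.batch_encode_plus(
        [[prompt, choice] for choice in choices], return_tensors="pt", pad_to_max_length=True
    )
    inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}
    labels = torch.tensor([0])  # choice 0 is the correct continuation

    outputs = model(**inputs, labels=labels)
    loss, logits = outputs[:2]
    print(logits.shape)  # torch.Size([1, 2]): one score per choice, before softmax

    # Illustrative only: the randomly initialised head needs fine-tuning before the
    # scores are meaningful; a real training loop would repeat this step over batches.
    loss.backward()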