diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py
index 286d0f0ea..21c51f971 100644
--- a/src/transformers/modeling_bart.py
+++ b/src/transformers/modeling_bart.py
@@ -1324,7 +1324,7 @@ class BartForSequenceClassification(PretrainedBartModel):
         # Prepend logits
         outputs = (logits,) + outputs[1:]  # Add hidden states and attention if they are here
         if labels is not None:  # prepend loss to output,
-            loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
             outputs = (loss,) + outputs
 
         return outputs
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index 89e41c79a..559046f66 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -171,7 +171,7 @@ class BartHeadTests(unittest.TestCase):
 
     vocab_size = 99
 
-    def test_lm_forward(self):
+    def _get_config_and_data(self, output_past=False):
         input_ids = torch.tensor(
             [
                 [71, 82, 18, 33, 46, 91, 2],
@@ -191,9 +191,8 @@ class BartHeadTests(unittest.TestCase):
             dtype=torch.long,
             device=torch_device,
         )
-        batch_size = input_ids.shape[0]
-        decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
 
+        batch_size = input_ids.shape[0]
         config = BartConfig(
             vocab_size=self.vocab_size,
             d_model=24,
@@ -204,14 +203,25 @@ class BartHeadTests(unittest.TestCase):
             encoder_ffn_dim=32,
             decoder_ffn_dim=32,
             max_position_embeddings=48,
+            output_past=output_past,
         )
+        return config, input_ids, batch_size
+
+    def test_sequence_classification_forward(self):
+        config, input_ids, batch_size = self._get_config_and_data()
+        labels = _long_tensor([2] * batch_size).to(torch_device)
         model = BartForSequenceClassification(config)
         model.to(torch_device)
-        outputs = model.forward(input_ids=input_ids, decoder_input_ids=input_ids)
-        logits = outputs[0]
+        outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
+        logits = outputs[1]
         expected_shape = torch.Size((batch_size, config.num_labels))
         self.assertEqual(logits.shape, expected_shape)
+        loss = outputs[0]
+        self.assertIsInstance(loss.item(), float)
 
+    def test_lm_forward(self):
+        config, input_ids, batch_size = self._get_config_and_data(output_past=False)
+        decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
         lm_model = BartForMaskedLM(config)
         lm_model.to(torch_device)
         loss, logits, enc_features = lm_model.forward(
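
For anyone who wants to poke at the fixed loss path outside the test suite, below is a minimal standalone sketch (not part of the patch). It mirrors the new test_sequence_classification_forward; the encoder/decoder layer and head counts are placeholder values, since the hunk above elides that part of _get_config_and_data, and it assumes a transformers build that includes this change.

# Illustrative sketch only, not part of the patch.
import torch
from transformers import BartConfig, BartForSequenceClassification

config = BartConfig(
    vocab_size=99,
    d_model=24,
    encoder_layers=2,            # placeholder; elided from the hunk above
    decoder_layers=2,            # placeholder; elided from the hunk above
    encoder_attention_heads=2,   # placeholder; elided from the hunk above
    decoder_attention_heads=2,   # placeholder; elided from the hunk above
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    max_position_embeddings=48,
    num_labels=3,  # explicit, so the label value 2 below is in range
)
model = BartForSequenceClassification(config)
model.eval()

# Rows end in 2 (the eos token id), which the classification head pools on.
input_ids = torch.tensor([[71, 82, 18, 33, 46, 91, 2]], dtype=torch.long)
labels = torch.tensor([2], dtype=torch.long)

# With labels passed, forward prepends the loss: (loss, logits, ...).
# The loss term now reads num_labels from self.config instead of
# expecting a num_labels attribute on the module itself.
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
loss, logits = outputs[0], outputs[1]
assert logits.shape == (1, config.num_labels)
assert isinstance(loss.item(), float)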