mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
BartForSequenceClassification: fix num_labels, add test (#3110)
This commit is contained in:
parent
f631e01d2c
commit
e9e6efdc45
2 changed files with 16 additions and 6 deletions
|
|
@ -1324,7 +1324,7 @@ class BartForSequenceClassification(PretrainedBartModel):
|
|||
# Prepend logits
|
||||
outputs = (logits,) + outputs[1:] # Add hidden states and attention if they are here
|
||||
if labels is not None: # prepend loss to output,
|
||||
loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class BartHeadTests(unittest.TestCase):
|
|||
|
||||
vocab_size = 99
|
||||
|
||||
def test_lm_forward(self):
|
||||
def _get_config_and_data(self, output_past=False):
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[71, 82, 18, 33, 46, 91, 2],
|
||||
|
|
@ -191,9 +191,8 @@ class BartHeadTests(unittest.TestCase):
|
|||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
batch_size = input_ids.shape[0]
|
||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
config = BartConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=24,
|
||||
|
|
@ -204,14 +203,25 @@ class BartHeadTests(unittest.TestCase):
|
|||
encoder_ffn_dim=32,
|
||||
decoder_ffn_dim=32,
|
||||
max_position_embeddings=48,
|
||||
output_past=output_past,
|
||||
)
|
||||
return config, input_ids, batch_size
|
||||
|
||||
def test_sequence_classification_forward(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data()
|
||||
labels = _long_tensor([2] * batch_size).to(torch_device)
|
||||
model = BartForSequenceClassification(config)
|
||||
model.to(torch_device)
|
||||
outputs = model.forward(input_ids=input_ids, decoder_input_ids=input_ids)
|
||||
logits = outputs[0]
|
||||
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=labels)
|
||||
logits = outputs[1]
|
||||
expected_shape = torch.Size((batch_size, config.num_labels))
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
loss = outputs[0]
|
||||
self.assertIsInstance(loss.item(), float)
|
||||
|
||||
def test_lm_forward(self):
|
||||
config, input_ids, batch_size = self._get_config_and_data(output_past=False)
|
||||
decoder_lm_labels = ids_tensor([batch_size, input_ids.shape[1]], self.vocab_size)
|
||||
lm_model = BartForMaskedLM(config)
|
||||
lm_model.to(torch_device)
|
||||
loss, logits, enc_features = lm_model.forward(
|
||||
|
|
|
|||
Loading…
Reference in a new issue