mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
Integration test for electra model (#10073)
This commit is contained in:
parent
781220acab
commit
263fac71a2
1 changed files with 16 additions and 0 deletions
|
|
@ -344,3 +344,19 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
|
|||
for model_name in ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
model = ElectraModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ElectraModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_no_head_absolute_embedding(self):
|
||||
model = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
|
||||
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
|
||||
output = model(input_ids)[0]
|
||||
expected_shape = torch.Size((1, 11))
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
expected_slice = torch.tensor(
|
||||
[[-8.9253, -4.0305, -3.9306, -3.8774, -4.1873, -4.1280, 0.9429, -4.1672, 0.9281, 0.0410, -3.4823]]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(output, expected_slice, atol=1e-4))
|
||||
|
|
|
|||
Loading…
Reference in a new issue