Integration test for electra model (#10073)

This commit is contained in:
sandip 2021-02-09 02:12:25 +05:30 committed by GitHub
parent 781220acab
commit 263fac71a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -344,3 +344,19 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
for model_name in ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = ElectraModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_torch
class ElectraModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head_absolute_embedding(self):
model = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 11))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[-8.9253, -4.0305, -3.9306, -3.8774, -4.1873, -4.1280, 0.9429, -4.1672, 0.9281, 0.0410, -3.4823]]
)
self.assertTrue(torch.allclose(output, expected_slice, atol=1e-4))