diff --git a/tests/test_modeling_electra.py b/tests/test_modeling_electra.py index 5e6d8baeb..601ab6b29 100644 --- a/tests/test_modeling_electra.py +++ b/tests/test_modeling_electra.py @@ -344,3 +344,19 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): for model_name in ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = ElectraModel.from_pretrained(model_name) self.assertIsNotNone(model) + + +@require_torch +class ElectraModelIntegrationTest(unittest.TestCase): + @slow + def test_inference_no_head_absolute_embedding(self): + model = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator") + input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]]) + output = model(input_ids)[0] + expected_shape = torch.Size((1, 11)) + self.assertEqual(output.shape, expected_shape) + expected_slice = torch.tensor( + [[-8.9253, -4.0305, -3.9306, -3.8774, -4.1873, -4.1280, 0.9429, -4.1672, 0.9281, 0.0410, -3.4823]] + ) + + self.assertTrue(torch.allclose(output, expected_slice, atol=1e-4))