diff --git a/tests/models/camembert/test_modeling_camembert.py b/tests/models/camembert/test_modeling_camembert.py index 3a40f6a87..a15ab8caa 100644 --- a/tests/models/camembert/test_modeling_camembert.py +++ b/tests/models/camembert/test_modeling_camembert.py @@ -39,7 +39,8 @@ class CamembertModelIntegrationTest(unittest.TestCase): device=torch_device, dtype=torch.long, ) # J'aime le camembert ! - output = model(input_ids)["last_hidden_state"] + with torch.no_grad(): + output = model(input_ids)["last_hidden_state"] expected_shape = torch.Size((1, 10, 768)) self.assertEqual(output.shape, expected_shape) # compare the actual values for a slice.