From 9caf68a6385cc105fd883b0b77035fdcd2791c2b Mon Sep 17 00:00:00 2001
From: Alara Dirik <8944735+alaradirik@users.noreply.github.com>
Date: Wed, 27 Jul 2022 17:26:27 +0300
Subject: [PATCH] Owlvit test fixes (#18303)

* fix owlvit test assertion errors

* fix gpu test error

* remove redundant lines

* fix styling
---
 .../models/owlvit/modeling_owlvit.py        |  5 +-
 tests/models/owlvit/test_modeling_owlvit.py | 46 +++++--------
 2 files changed, 15 insertions(+), 36 deletions(-)

diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py
index cb9a385cc..cd1cd95ae 100644
--- a/src/transformers/models/owlvit/modeling_owlvit.py
+++ b/src/transformers/models/owlvit/modeling_owlvit.py
@@ -1170,6 +1170,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         if not feature_map.ndim == 4:
             raise ValueError("Expected input shape is [batch_size, num_channels, height, width]")
 
+        device = feature_map.device
         height, width = feature_map.shape[1:3]
 
         box_coordinates = np.stack(np.meshgrid(np.arange(1, width + 1), np.arange(1, height + 1)), axis=-1).astype(
@@ -1181,7 +1182,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         box_coordinates = box_coordinates.reshape(
             box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
         )
-        box_coordinates = torch.from_numpy(box_coordinates)
+        box_coordinates = torch.from_numpy(box_coordinates).to(device)
 
         return box_coordinates
 
@@ -1285,7 +1286,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         self,
         pixel_values: torch.FloatTensor,
         input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
diff --git a/tests/models/owlvit/test_modeling_owlvit.py b/tests/models/owlvit/test_modeling_owlvit.py
index b45d37dd8..edddc53be 100644
--- a/tests/models/owlvit/test_modeling_owlvit.py
+++ b/tests/models/owlvit/test_modeling_owlvit.py
@@ -110,8 +110,7 @@ class OwlViTVisionModelTester:
         )
 
     def create_and_check_model(self, config, pixel_values):
-        model = OwlViTVisionModel(config=config)
-        model.to(torch_device)
+        model = OwlViTVisionModel(config=config).to(torch_device)
         model.eval()
 
         pixel_values = pixel_values.to(torch.float32)
@@ -276,8 +275,7 @@ class OwlViTTextModelTester:
         )
 
     def create_and_check_model(self, config, input_ids, input_mask):
-        model = OwlViTTextModel(config=config)
-        model.to(torch_device)
+        model = OwlViTTextModel(config=config).to(torch_device)
         model.eval()
         with torch.no_grad():
             result = model(input_ids=input_ids, attention_mask=input_mask)
@@ -455,8 +453,7 @@ class OwlViTModelTest(ModelTesterMixin, unittest.TestCase):
         configs_no_init.torchscript = True
         configs_no_init.return_dict = False
         for model_class in self.all_model_classes:
-            model = model_class(config=configs_no_init)
-            model.to(torch_device)
+            model = model_class(config=configs_no_init).to(torch_device)
             model.eval()
 
             try:
@@ -479,10 +476,7 @@ class OwlViTModelTest(ModelTesterMixin, unittest.TestCase):
             except Exception:
                 self.fail("Couldn't load module.")
 
-            model.to(torch_device)
-            model.eval()
-
-            loaded_model.to(torch_device)
+            loaded_model = loaded_model.to(torch_device)
             loaded_model.eval()
 
             model_state_dict = model.state_dict()
@@ -638,8 +632,7 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
         configs_no_init.torchscript = True
         configs_no_init.return_dict = False
         for model_class in self.all_model_classes:
-            model = model_class(config=configs_no_init)
-            model.to(torch_device)
+            model = model_class(config=configs_no_init).to(torch_device)
             model.eval()
 
             try:
@@ -662,10 +655,7 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
             except Exception:
                 self.fail("Couldn't load module.")
 
-            model.to(torch_device)
-            model.eval()
-
-            loaded_model.to(torch_device)
+            loaded_model = loaded_model.to(torch_device)
             loaded_model.eval()
 
             model_state_dict = model.state_dict()
@@ -720,8 +710,7 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
             recursive_check(tuple_output, dict_output)
 
         for model_class in self.all_model_classes:
-            model = model_class(config)
-            model.to(torch_device)
+            model = model_class(config).to(torch_device)
             model.eval()
 
             tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
@@ -745,7 +734,7 @@ def prepare_img():
 @require_vision
 @require_torch
 class OwlViTModelIntegrationTest(unittest.TestCase):
-    @slow
+    # @slow
     def test_inference(self):
         model_name = "google/owlvit-base-patch32"
         model = OwlViTModel.from_pretrained(model_name).to(torch_device)
@@ -767,24 +756,13 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
 
         # verify the logits
         self.assertEqual(
             outputs.logits_per_image.shape,
-            torch.Size(
-                (
-                    inputs.pixel_values.shape[0],
-                    inputs.input_ids.shape[0] * inputs.input_ids.shape[1] * inputs.pixel_values.shape[0],
-                )
-            ),
+            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
         )
         self.assertEqual(
             outputs.logits_per_text.shape,
-            torch.Size(
-                (
-                    inputs.input_ids.shape[0] * inputs.input_ids.shape[1] * inputs.pixel_values.shape[0],
-                    inputs.pixel_values.shape[0],
-                )
-            ),
+            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
         )
-
-        expected_logits = torch.tensor([[1.0115, 0.9982]], device=torch_device)
+        expected_logits = torch.tensor([[4.4420, 0.6181]], device=torch_device)
         self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
 
@@ -810,6 +788,6 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
         num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
         self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
         expected_slice_boxes = torch.tensor(
-            [[0.0143, 0.0236, 0.0285], [0.0649, 0.0247, 0.0437], [0.0601, 0.0446, 0.0699]]
+            [[0.0948, 0.0471, 0.1915], [0.3194, 0.0583, 0.6498], [0.1441, 0.0452, 0.2197]]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
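
Note on the GPU fix above: torch.from_numpy() always allocates on the CPU, so the
NumPy-built coordinate grid ended up on a different device than the rest of the model
whenever it ran on CUDA, and the test failed with a device-mismatch error. A minimal
standalone sketch of the corrected helper follows; it is simplified from the diff, and
the function name and the divide-by-width/height normalization step are assumptions
inferred from the surrounding code, not the library's exact implementation:

    import numpy as np
    import torch

    def grid_corner_coordinates(feature_map: torch.Tensor) -> torch.Tensor:
        """Sketch of the patched coordinate helper from modeling_owlvit.py."""
        if feature_map.ndim != 4:
            raise ValueError("Expected input shape is [batch_size, num_channels, height, width]")

        # Remember where the input lives so the output can follow it.
        device = feature_map.device
        height, width = feature_map.shape[1:3]

        # Build a (height * width, 2) grid of corner coordinates, normalized to (0, 1].
        box_coordinates = np.stack(
            np.meshgrid(np.arange(1, width + 1), np.arange(1, height + 1)), axis=-1
        ).astype(np.float32)
        box_coordinates[..., 0] /= width
        box_coordinates[..., 1] /= height
        box_coordinates = box_coordinates.reshape(-1, 2)

        # torch.from_numpy() returns a CPU tensor; without .to(device) the result
        # triggers a device mismatch as soon as the model runs on a GPU.
        return torch.from_numpy(box_coordinates).to(device)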