diff --git a/tests/test_modeling_lxmert.py b/tests/test_modeling_lxmert.py
index 5484f2b71..56b68b92e 100644
--- a/tests/test_modeling_lxmert.py
+++ b/tests/test_modeling_lxmert.py
@@ -120,8 +120,8 @@ class LxmertModelTester:
 
         output_attentions = self.output_attentions
         input_ids = ids_tensor([self.batch_size, self.seq_length], vocab_size=self.vocab_size)
-        visual_feats = torch.rand(self.batch_size, self.num_visual_features, self.visual_feat_dim)
-        bounding_boxes = torch.rand(self.batch_size, self.num_visual_features, 4)
+        visual_feats = torch.rand(self.batch_size, self.num_visual_features, self.visual_feat_dim, device=torch_device)
+        bounding_boxes = torch.rand(self.batch_size, self.num_visual_features, 4, device=torch_device)
 
         input_mask = None
         if self.use_lang_mask:
@@ -407,8 +407,8 @@ class LxmertModelTester:
         num_small_labels = int(config.num_qa_labels * 2)
         less_labels_ans = ids_tensor([self.batch_size], num_small_labels)
         more_labels_ans = ids_tensor([self.batch_size], num_large_labels)
-        model_pretrain = LxmertForPreTraining(config=config)
-        model_qa = LxmertForQuestionAnswering(config=config)
+        model_pretrain = LxmertForPreTraining(config=config).to(torch_device)
+        model_qa = LxmertForQuestionAnswering(config=config).to(torch_device)
         config.num_labels = num_small_labels
         end_labels = config.num_labels
 
@@ -560,6 +560,7 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
     def test_model_from_pretrained(self):
         for model_name in LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
             model = LxmertModel.from_pretrained(model_name)
+            model.to(torch_device)
             self.assertIsNotNone(model)
 
     def test_attention_outputs(self):