diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index d16241ab2..98a3f8383 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -611,7 +611,7 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): pt_outputs = pt_model(**pt_inputs).to_tuple() # convert inputs to Flax - fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_outputs = fx_model(**fx_inputs).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): @@ -669,7 +669,7 @@ class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() - fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_outputs = fx_model(**fx_inputs).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") diff --git a/tests/models/clipseg/test_modeling_clipseg.py b/tests/models/clipseg/test_modeling_clipseg.py index b54861d8d..35dca117e 100644 --- a/tests/models/clipseg/test_modeling_clipseg.py +++ b/tests/models/clipseg/test_modeling_clipseg.py @@ -592,7 +592,7 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) pt_outputs = pt_model(**pt_inputs).to_tuple() # convert inputs to Flax - fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_outputs = fx_model(**fx_inputs).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): @@ -650,7 +650,7 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() - fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_outputs = fx_model(**fx_inputs).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 20bbfbac3..614b29b16 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -875,7 +875,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi } # convert inputs to Flax - fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_model.params = fx_state @@ -948,7 +948,7 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi } # convert inputs to Flax - fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) @@ -1805,7 +1805,7 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest. } # convert inputs to Flax - fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_model.params = fx_state @@ -1878,7 +1878,7 @@ class WhisperEncoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest. } # convert inputs to Flax - fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 878e3c647..7e7ca36b4 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2206,7 +2206,7 @@ class ModelTesterMixin: } # convert inputs to Flax - fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_model.params = fx_state @@ -2278,7 +2278,7 @@ class ModelTesterMixin: } # convert inputs to Flax - fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)