From ddbb485c41e63dfbd7c2667e01bbe2ab5b3fe660 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 28 Feb 2022 21:46:46 +0100 Subject: [PATCH] [TF-PT-Tests] Fix PyTorch - TF tests for different GPU devices (#15846) --- tests/test_modeling_common.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b6ec0eae8..b18910d10 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1493,9 +1493,8 @@ class ModelTesterMixin: tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict) pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device) - # need to rename encoder-decoder "inputs" for PyTorch - # if "inputs" in pt_inputs_dict and self.is_encoder_decoder: - # pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs") + # Make sure PyTorch tensors are on same device as model + pt_inputs = {k: v.to(torch_device) if torch.is_tensor(v) else v for k, v in pt_inputs.items()} with torch.no_grad(): pto = pt_model(**pt_inputs)