[TF-PT-Tests] Fix PyTorch - TF tests for different GPU devices (#15846)

This commit is contained in:
Patrick von Platen 2022-02-28 21:46:46 +01:00 committed by GitHub
parent 97f9b8a27b
commit ddbb485c41
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -1493,9 +1493,8 @@ class ModelTesterMixin:
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
# need to rename encoder-decoder "inputs" for PyTorch
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
# Make sure PyTorch tensors are on same device as model
pt_inputs = {k: v.to(torch_device) if torch.is_tensor(v) else v for k, v in pt_inputs.items()}
with torch.no_grad():
pto = pt_model(**pt_inputs)