mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
[TF-PT-Tests] Fix PyTorch - TF tests for different GPU devices (#15846)
This commit is contained in:
parent
97f9b8a27b
commit
ddbb485c41
1 changed files with 2 additions and 3 deletions
|
|
@ -1493,9 +1493,8 @@ class ModelTesterMixin:
|
|||
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
|
||||
|
||||
# need to rename encoder-decoder "inputs" for PyTorch
|
||||
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
|
||||
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
|
||||
# Make sure PyTorch tensors are on same device as model
|
||||
pt_inputs = {k: v.to(torch_device) if torch.is_tensor(v) else v for k, v in pt_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pto = pt_model(**pt_inputs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue