mirror of
https://github.com/saymrwulf/transformers.git
synced 2026-05-14 20:58:08 +00:00
tests: fix pytorch tensor placement errors (#33485)
This commit fixes the following errors: * Fix "expected all tensors to be on the same device" error * Fix "can't convert device type tensor to numpy" According to pytorch documentation torch.Tensor.numpy(force=False) performs conversion only if tensor is on CPU (plus few other restrictions) which is not the case. For our case we need force=True since we just need a data and don't care about tensors coherency. Fixes: #33517 See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
parent
52daf4ec76
commit
5e2916bc14
8 changed files with 29 additions and 26 deletions
|
|
@ -163,7 +163,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
|
|||
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
|
||||
if v.dtype == bfloat16:
|
||||
v = v.float()
|
||||
pt_state_dict[k] = v.numpy()
|
||||
pt_state_dict[k] = v.cpu().numpy()
|
||||
|
||||
model_prefix = flax_model.base_model_prefix
|
||||
|
||||
|
|
|
|||
|
|
@ -848,6 +848,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||
with self.subTest(model_class.__name__):
|
||||
# load PyTorch class
|
||||
pt_model = model_class(config).eval()
|
||||
pt_model.to(torch_device)
|
||||
# Flax models don't use the `use_cache` option and cache is not returned as a default.
|
||||
# So we disable `use_cache` here for PyTorch model.
|
||||
pt_model.config.use_cache = False
|
||||
|
|
@ -881,7 +882,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||
fx_outputs = fx_model(**fx_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
|
|
@ -892,7 +893,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# overwrite from common since FlaxCLIPModel returns nested output
|
||||
# which is not supported in the common test
|
||||
|
|
@ -921,6 +922,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
|
||||
|
||||
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
|
||||
pt_model.to(torch_device)
|
||||
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
|
@ -940,11 +942,12 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
pt_model_loaded.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
|
|
@ -953,7 +956,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
|
|||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
|
|
|
|||
|
|
@ -297,7 +297,7 @@ class FlaxEncoderDecoderMixin:
|
|||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
|
@ -305,7 +305,7 @@ class FlaxEncoderDecoderMixin:
|
|||
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
|
@ -315,7 +315,7 @@ class FlaxEncoderDecoderMixin:
|
|||
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
|
@ -330,7 +330,7 @@ class FlaxEncoderDecoderMixin:
|
|||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
|
||||
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ class InformerModelTester:
|
|||
|
||||
embed_positions = InformerSinusoidalPositionalEmbedding(
|
||||
config.context_length + config.prediction_length, config.d_model
|
||||
)
|
||||
).to(torch_device)
|
||||
self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight))
|
||||
self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight))
|
||||
|
||||
|
|
|
|||
|
|
@ -412,7 +412,7 @@ class FlaxEncoderDecoderMixin:
|
|||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
|
@ -420,7 +420,7 @@ class FlaxEncoderDecoderMixin:
|
|||
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
|
@ -430,7 +430,7 @@ class FlaxEncoderDecoderMixin:
|
|||
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
|
@ -445,7 +445,7 @@ class FlaxEncoderDecoderMixin:
|
|||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
|
||||
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
|
|
|||
|
|
@ -241,7 +241,7 @@ class FlaxEncoderDecoderMixin:
|
|||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
|
@ -249,7 +249,7 @@ class FlaxEncoderDecoderMixin:
|
|||
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
|
@ -259,7 +259,7 @@ class FlaxEncoderDecoderMixin:
|
|||
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
|
@ -274,7 +274,7 @@ class FlaxEncoderDecoderMixin:
|
|||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
|
||||
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||
|
|
|
|||
|
|
@ -160,7 +160,7 @@ class VisionTextDualEncoderMixin:
|
|||
|
||||
# prepare inputs
|
||||
flax_inputs = inputs_dict
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
|
||||
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
|
@ -168,7 +168,7 @@ class VisionTextDualEncoderMixin:
|
|||
fx_outputs = fx_model(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
|
@ -178,7 +178,7 @@ class VisionTextDualEncoderMixin:
|
|||
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
|
@ -193,7 +193,7 @@ class VisionTextDualEncoderMixin:
|
|||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
|
|
|||
|
|
@ -179,7 +179,7 @@ class VisionTextDualEncoderMixin:
|
|||
# prepare inputs
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
|
||||
pt_inputs = inputs_dict
|
||||
flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}
|
||||
flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
|
@ -187,7 +187,7 @@ class VisionTextDualEncoderMixin:
|
|||
fx_outputs = fx_model(**flax_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# PT -> Flax
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
|
@ -197,7 +197,7 @@ class VisionTextDualEncoderMixin:
|
|||
fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
|
||||
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
|
||||
|
||||
# Flax -> PT
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
|
@ -212,7 +212,7 @@ class VisionTextDualEncoderMixin:
|
|||
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
|
||||
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
|
||||
|
||||
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
|
||||
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)
|
||||
|
|
|
|||
Loading…
Reference in a new issue