tests: fix pytorch tensor placement errors (#33485)

This commit fixes the following errors:
* Fix "expected all tensors to be on the same device" error
* Fix "can't convert device type tensor to numpy"

According to pytorch documentation torch.Tensor.numpy(force=False)
performs conversion only if tensor is on CPU (plus few other restrictions)
which is not the case. For our case we need force=True since we just
need a data and don't care about tensors coherency.

Fixes: #33517
See: https://pytorch.org/docs/2.4/generated/torch.Tensor.numpy.html

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
Dmitry Rogozhkin 2024-09-25 04:21:53 -07:00 committed by GitHub
parent 52daf4ec76
commit 5e2916bc14
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 29 additions and 26 deletions

View file

@ -163,7 +163,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
if v.dtype == bfloat16:
v = v.float()
pt_state_dict[k] = v.numpy()
pt_state_dict[k] = v.cpu().numpy()
model_prefix = flax_model.base_model_prefix

View file

@ -848,6 +848,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
with self.subTest(model_class.__name__):
# load PyTorch class
pt_model = model_class(config).eval()
pt_model.to(torch_device)
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
@ -881,7 +882,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
fx_outputs = fx_model(**fx_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
@ -892,7 +893,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@ -921,6 +922,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys()
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
pt_model.to(torch_device)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
@ -940,11 +942,12 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True)
pt_model_loaded.to(torch_device)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
@ -953,7 +956,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
@slow
def test_model_from_pretrained(self):

View file

@ -297,7 +297,7 @@ class FlaxEncoderDecoderMixin:
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
@ -305,7 +305,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@ -315,7 +315,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@ -330,7 +330,7 @@ class FlaxEncoderDecoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

View file

@ -170,7 +170,7 @@ class InformerModelTester:
embed_positions = InformerSinusoidalPositionalEmbedding(
config.context_length + config.prediction_length, config.d_model
)
).to(torch_device)
self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight))
self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight))

View file

@ -412,7 +412,7 @@ class FlaxEncoderDecoderMixin:
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
@ -420,7 +420,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@ -430,7 +430,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@ -445,7 +445,7 @@ class FlaxEncoderDecoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

View file

@ -241,7 +241,7 @@ class FlaxEncoderDecoderMixin:
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
@ -249,7 +249,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 1e-5)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@ -259,7 +259,7 @@ class FlaxEncoderDecoderMixin:
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-5)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 1e-5)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@ -274,7 +274,7 @@ class FlaxEncoderDecoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 1e-5)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 1e-5)
def check_equivalence_pt_to_flax(self, config, decoder_config, inputs_dict):
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)

View file

@ -160,7 +160,7 @@ class VisionTextDualEncoderMixin:
# prepare inputs
flax_inputs = inputs_dict
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()}
pt_inputs = {k: torch.tensor(v.tolist()).to(torch_device) for k, v in flax_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
@ -168,7 +168,7 @@ class VisionTextDualEncoderMixin:
fx_outputs = fx_model(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@ -178,7 +178,7 @@ class VisionTextDualEncoderMixin:
fx_outputs_loaded = fx_model_loaded(**inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@ -193,7 +193,7 @@ class VisionTextDualEncoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)

View file

@ -179,7 +179,7 @@ class VisionTextDualEncoderMixin:
# prepare inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": pixel_values}
pt_inputs = inputs_dict
flax_inputs = {k: v.numpy() for k, v in pt_inputs.items()}
flax_inputs = {k: v.numpy(force=True) for k, v in pt_inputs.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
@ -187,7 +187,7 @@ class VisionTextDualEncoderMixin:
fx_outputs = fx_model(**flax_inputs).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output.numpy(force=True), 4e-2)
# PT -> Flax
with tempfile.TemporaryDirectory() as tmpdirname:
@ -197,7 +197,7 @@ class VisionTextDualEncoderMixin:
fx_outputs_loaded = fx_model_loaded(**flax_inputs).to_tuple()
self.assertEqual(len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(force=True), 4e-2)
# Flax -> PT
with tempfile.TemporaryDirectory() as tmpdirname:
@ -212,7 +212,7 @@ class VisionTextDualEncoderMixin:
self.assertEqual(len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output_loaded in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(), 4e-2)
self.assert_almost_equals(fx_output, pt_output_loaded.numpy(force=True), 4e-2)
def check_equivalence_pt_to_flax(self, vision_config, text_config, inputs_dict):
config = VisionTextDualEncoderConfig.from_vision_text_configs(vision_config, text_config)