From aebca696afb511013d71fb6da5c162d945fb31d4 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 29 Mar 2022 17:51:48 +0200 Subject: [PATCH] Fix missing output_attentions in PT/Flax equivalence test (#16271) * fix - set output_attentions to True * Update tests/test_modeling_flax_common.py * update for has_attentions * overwrite check_outputs in FlaxBigBirdModelTest Co-authored-by: ydshieh Co-authored-by: Suraj Patil --- tests/big_bird/test_modeling_flax_big_bird.py | 9 +++++++++ tests/test_modeling_flax_common.py | 9 ++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/big_bird/test_modeling_flax_big_bird.py b/tests/big_bird/test_modeling_flax_big_bird.py index 834b71b30..594612931 100644 --- a/tests/big_bird/test_modeling_flax_big_bird.py +++ b/tests/big_bird/test_modeling_flax_big_bird.py @@ -190,3 +190,12 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase): for jitted_output, output in zip(jitted_outputs, outputs): self.assertEqual(jitted_output.shape, output.shape) + + # overwrite from common in order to skip the check on `attentions` + def check_outputs(self, fx_outputs, pt_outputs, model_class, names): + # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version, + # an effort was made to return `attention_probs` (yet to be verified). 
+ if type(names) == str and names.startswith("attentions"): + return + else: + super().check_outputs(fx_outputs, pt_outputs, model_class, names) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 0de5005fe..e37352b97 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -120,6 +120,7 @@ class FlaxModelTesterMixin: test_mismatched_shapes = True is_encoder_decoder = False test_head_masking = False + has_attentions = True def _prepare_for_class(self, inputs_dict, model_class): inputs_dict = copy.deepcopy(inputs_dict) @@ -168,6 +169,7 @@ class FlaxModelTesterMixin: dict_inputs = self._prepare_for_class(inputs_dict, model_class) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + # (Copied from tests.test_modeling_common.ModelTesterMixin.check_outputs) def check_outputs(self, fx_outputs, pt_outputs, model_class, names): """ Args: @@ -204,8 +206,7 @@ class FlaxModelTesterMixin: pt_outputs[pt_nans] = 0 fx_outputs[pt_nans] = 0 - max_diff = np.amax(np.abs(fx_outputs - pt_outputs)) - self.assertLessEqual(max_diff, 1e-5) + self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5) else: raise ValueError( f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead." 
@@ -222,6 +223,7 @@ class FlaxModelTesterMixin: # Output all for aggressive testing config.output_hidden_states = True + config.output_attentions = self.has_attentions # prepare inputs prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) @@ -274,7 +276,7 @@ class FlaxModelTesterMixin: # Output all for aggressive testing config.output_hidden_states = True - # Pure convolutional models have no attention + config.output_attentions = self.has_attentions # prepare inputs prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) @@ -314,6 +316,7 @@ class FlaxModelTesterMixin: # send pytorch model to the correct device pt_model_loaded.to(torch_device) + pt_model_loaded.eval() with torch.no_grad(): pt_outputs_loaded = pt_model_loaded(**pt_inputs)