diff --git a/tests/big_bird/test_modeling_flax_big_bird.py b/tests/big_bird/test_modeling_flax_big_bird.py
index 834b71b30..594612931 100644
--- a/tests/big_bird/test_modeling_flax_big_bird.py
+++ b/tests/big_bird/test_modeling_flax_big_bird.py
@@ -190,3 +190,12 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
 
                 for jitted_output, output in zip(jitted_outputs, outputs):
                     self.assertEqual(jitted_output.shape, output.shape)
+
+    # overwrite from common in order to skip the check on `attentions`
+    def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
+        # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in the PyTorch
+        # version an effort was made to return `attention_probs` (yet to be verified).
+        if type(names) == str and names.startswith("attentions"):
+            return
+        else:
+            super().check_outputs(fx_outputs, pt_outputs, model_class, names)
diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py
index 0de5005fe..e37352b97 100644
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -120,6 +120,7 @@ class FlaxModelTesterMixin:
     test_mismatched_shapes = True
     is_encoder_decoder = False
     test_head_masking = False
+    has_attentions = True
 
     def _prepare_for_class(self, inputs_dict, model_class):
         inputs_dict = copy.deepcopy(inputs_dict)
@@ -168,6 +169,7 @@
             dict_inputs = self._prepare_for_class(inputs_dict, model_class)
             check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
 
+    # (Copied from tests.test_modeling_common.ModelTesterMixin.check_outputs)
     def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
         """
         Args:
@@ -204,8 +206,7 @@
                 pt_outputs[pt_nans] = 0
                 fx_outputs[pt_nans] = 0
 
-            max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
-            self.assertLessEqual(max_diff, 1e-5)
+            self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5)
         else:
             raise ValueError(
                 f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
@@ -222,6 +223,7 @@
 
             # Output all for aggressive testing
             config.output_hidden_states = True
+            config.output_attentions = self.has_attentions
 
             # prepare inputs
             prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
@@ -274,7 +276,7 @@
 
             # Output all for aggressive testing
             config.output_hidden_states = True
-            # Pure convolutional models have no attention
+            config.output_attentions = self.has_attentions
 
             # prepare inputs
             prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
@@ -314,6 +316,7 @@
 
             # send pytorch model to the correct device
             pt_model_loaded.to(torch_device)
+            pt_model_loaded.eval()
 
             with torch.no_grad():
                 pt_outputs_loaded = pt_model_loaded(**pt_inputs)
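
For context, here is a minimal, self-contained sketch of the comparison pattern these hunks converge on: walk nested outputs recursively, skip outputs by name where a model (here BigBird) cannot produce them, zero out NaN positions on both sides, and assert the largest element-wise difference stays within tolerance. Only the names `check_outputs` and `assert_almost_equals` come from the diff; the standalone harness and the toy inputs below are hypothetical and use plain numpy rather than the test mixin.

```python
import numpy as np


def assert_almost_equals(fx: np.ndarray, pt: np.ndarray, tol: float) -> None:
    # The check the diff factors out of `check_outputs`: fail when the largest
    # element-wise difference exceeds the tolerance.
    max_diff = np.amax(np.abs(fx - pt))
    assert max_diff <= tol, f"max difference {max_diff:.2e} > tolerance {tol:.0e}"


def check_outputs(fx_outputs, pt_outputs, names, tol=1e-5):
    # Folded-in version of the BigBird override above: skip any output whose
    # name starts with "attentions", since Flax BigBird's block-sparse
    # attention returns `attention_probs = None`.
    if isinstance(names, str) and names.startswith("attentions"):
        return
    # Recurse through tuples/lists so nested outputs (hidden states,
    # attentions, ...) are compared leaf by leaf.
    if isinstance(fx_outputs, (tuple, list)):
        for fx, pt, name in zip(fx_outputs, pt_outputs, names):
            check_outputs(fx, pt, name, tol)
    elif isinstance(fx_outputs, np.ndarray):
        fx = np.array(fx_outputs, dtype=np.float64)
        pt = np.array(pt_outputs, dtype=np.float64)
        # Zero out positions where either side is NaN before comparing,
        # mirroring the NaN handling in the mixin's method.
        nans = np.isnan(fx) | np.isnan(pt)
        fx[nans] = 0
        pt[nans] = 0
        assert_almost_equals(fx, pt, tol)
    else:
        raise ValueError(f"`fx_outputs` should be a `tuple` or an `ndarray`. Got {type(fx_outputs)} instead.")


# Toy usage: the hidden states agree to within 1e-5; the attention leaf
# disagrees wildly but is skipped by name, so the whole check passes.
fx_hidden, pt_hidden = np.ones((2, 3)), np.ones((2, 3)) + 1e-6
fx_attn, pt_attn = np.zeros((4,)), np.full((4,), 123.0)
check_outputs((fx_hidden, [fx_attn]), (pt_hidden, [pt_attn]), ("hidden_states", ["attentions_0"]))
```

The `pt_model_loaded.eval()` addition in the last hunk matters for the same tolerance: with the PyTorch model left in train mode, dropout stays active and the PT and Flax outputs would diverge by far more than 1e-5.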